diff options
Diffstat (limited to 'test')
74 files changed, 2701 insertions, 885 deletions
diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD index fb0b2db41..dfcd55f60 100644 --- a/test/packetdrill/BUILD +++ b/test/packetdrill/BUILD @@ -1,4 +1,4 @@ -load("defs.bzl", "packetdrill_linux_test", "packetdrill_netstack_test", "packetdrill_test") +load("defs.bzl", "packetdrill_test") package(licenses = ["notice"]) @@ -17,16 +17,6 @@ packetdrill_test( scripts = ["fin_wait2_timeout.pkt"], ) -packetdrill_linux_test( - name = "tcp_user_timeout_test_linux_test", - scripts = ["linux/tcp_user_timeout.pkt"], -) - -packetdrill_netstack_test( - name = "tcp_user_timeout_test_netstack_test", - scripts = ["netstack/tcp_user_timeout.pkt"], -) - packetdrill_test( name = "listen_close_before_handshake_complete_test", scripts = ["listen_close_before_handshake_complete.pkt"], diff --git a/test/packetdrill/linux/tcp_user_timeout.pkt b/test/packetdrill/linux/tcp_user_timeout.pkt deleted file mode 100644 index 38018cb42..000000000 --- a/test/packetdrill/linux/tcp_user_timeout.pkt +++ /dev/null @@ -1,39 +0,0 @@ -// Test that a socket w/ TCP_USER_TIMEOUT set aborts the connection -// if there is pending unacked data after the user specified timeout. - -0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 -+0 bind(3, ..., ...) = 0 - -+0 listen(3, 1) = 0 - -// Establish a connection without timestamps. -+0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> -+0 > S. 0:0(0) ack 1 <...> -+0.1 < . 1:1(0) ack 1 win 32792 - -+0.100 accept(3, ..., ...) = 4 - -// Okay, we received nothing, and decide to close this idle socket. -// We set TCP_USER_TIMEOUT to 3 seconds because really it is not worth -// trying hard to cleanly close this flow, at the price of keeping -// a TCP structure in kernel for about 1 minute! -+2 setsockopt(4, SOL_TCP, TCP_USER_TIMEOUT, [3000], 4) = 0 - -// The write/ack is required mainly for netstack as netstack does -// not update its RTO during the handshake. -+0 write(4, ..., 100) = 100 -+0 > P. 1:101(100) ack 1 <...> -+0 < . 1:1(0) ack 101 win 32792 - -+0 close(4) = 0 - -+0 > F. 101:101(0) ack 1 <...> -+.3~+.400 > F. 101:101(0) ack 1 <...> -+.3~+.400 > F. 101:101(0) ack 1 <...> -+.6~+.800 > F. 101:101(0) ack 1 <...> -+1.2~+1.300 > F. 101:101(0) ack 1 <...> - -// We finally receive something from the peer, but it is way too late -// Our socket vanished because TCP_USER_TIMEOUT was really small. -+.1 < . 1:2(1) ack 102 win 32792 -+0 > R 102:102(0) win 0 diff --git a/test/packetdrill/netstack/tcp_user_timeout.pkt b/test/packetdrill/netstack/tcp_user_timeout.pkt deleted file mode 100644 index 60103adba..000000000 --- a/test/packetdrill/netstack/tcp_user_timeout.pkt +++ /dev/null @@ -1,38 +0,0 @@ -// Test that a socket w/ TCP_USER_TIMEOUT set aborts the connection -// if there is pending unacked data after the user specified timeout. - -0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 -+0 bind(3, ..., ...) = 0 - -+0 listen(3, 1) = 0 - -// Establish a connection without timestamps. -+0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> -+0 > S. 0:0(0) ack 1 <...> -+0.1 < . 1:1(0) ack 1 win 32792 - -+0.100 accept(3, ..., ...) = 4 - -// Okay, we received nothing, and decide to close this idle socket. -// We set TCP_USER_TIMEOUT to 3 seconds because really it is not worth -// trying hard to cleanly close this flow, at the price of keeping -// a TCP structure in kernel for about 1 minute! -+2 setsockopt(4, SOL_TCP, TCP_USER_TIMEOUT, [3000], 4) = 0 - -// The write/ack is required mainly for netstack as netstack does -// not update its RTO during the handshake. -+0 write(4, ..., 100) = 100 -+0 > P. 1:101(100) ack 1 <...> -+0 < . 1:1(0) ack 101 win 32792 - -+0 close(4) = 0 - -+0 > F. 101:101(0) ack 1 <...> -+.2~+.300 > F. 101:101(0) ack 1 <...> -+.4~+.500 > F. 101:101(0) ack 1 <...> -+.8~+.900 > F. 101:101(0) ack 1 <...> - -// We finally receive something from the peer, but it is way too late -// Our socket vanished because TCP_USER_TIMEOUT was really small. -+1.61 < . 1:2(1) ack 102 win 32792 -+0 > R 102:102(0) win 0 diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc index 2f10dda40..86e580c6f 100644 --- a/test/packetimpact/dut/posix_server.cc +++ b/test/packetimpact/dut/posix_server.cc @@ -61,13 +61,15 @@ } class PosixImpl final : public posix_server::Posix::Service { - ::grpc::Status Socket(grpc_impl::ServerContext *context, - const ::posix_server::SocketRequest *request, - ::posix_server::SocketResponse *response) override { - response->set_fd( - socket(request->domain(), request->type(), request->protocol())); + ::grpc::Status Accept(grpc_impl::ServerContext *context, + const ::posix_server::AcceptRequest *request, + ::posix_server::AcceptResponse *response) override { + sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + response->set_fd(accept(request->sockfd(), + reinterpret_cast<sockaddr *>(&addr), &addrlen)); response->set_errno_(errno); - return ::grpc::Status::OK; + return sockaddr_to_proto(addr, addrlen, response->mutable_addr()); } ::grpc::Status Bind(grpc_impl::ServerContext *context, @@ -119,6 +121,14 @@ class PosixImpl final : public posix_server::Posix::Service { return ::grpc::Status::OK; } + ::grpc::Status Close(grpc_impl::ServerContext *context, + const ::posix_server::CloseRequest *request, + ::posix_server::CloseResponse *response) override { + response->set_ret(close(request->fd())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + ::grpc::Status GetSockName( grpc_impl::ServerContext *context, const ::posix_server::GetSockNameRequest *request, @@ -139,15 +149,13 @@ class PosixImpl final : public posix_server::Posix::Service { return ::grpc::Status::OK; } - ::grpc::Status Accept(grpc_impl::ServerContext *context, - const ::posix_server::AcceptRequest *request, - ::posix_server::AcceptResponse *response) override { - sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - response->set_fd(accept(request->sockfd(), - reinterpret_cast<sockaddr *>(&addr), &addrlen)); + ::grpc::Status Send(::grpc::ServerContext *context, + const ::posix_server::SendRequest *request, + ::posix_server::SendResponse *response) override { + response->set_ret(::send(request->sockfd(), request->buf().data(), + request->buf().size(), request->flags())); response->set_errno_(errno); - return sockaddr_to_proto(addr, addrlen, response->mutable_addr()); + return ::grpc::Status::OK; } ::grpc::Status SetSockOpt( @@ -161,6 +169,17 @@ class PosixImpl final : public posix_server::Posix::Service { return ::grpc::Status::OK; } + ::grpc::Status SetSockOptInt( + ::grpc::ServerContext *context, + const ::posix_server::SetSockOptIntRequest *request, + ::posix_server::SetSockOptIntResponse *response) override { + int opt = request->intval(); + response->set_ret(::setsockopt(request->sockfd(), request->level(), + request->optname(), &opt, sizeof(opt))); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + ::grpc::Status SetSockOptTimeval( ::grpc::ServerContext *context, const ::posix_server::SetSockOptTimevalRequest *request, @@ -174,11 +193,23 @@ class PosixImpl final : public posix_server::Posix::Service { return ::grpc::Status::OK; } - ::grpc::Status Close(grpc_impl::ServerContext *context, - const ::posix_server::CloseRequest *request, - ::posix_server::CloseResponse *response) override { - response->set_ret(close(request->fd())); + ::grpc::Status Socket(grpc_impl::ServerContext *context, + const ::posix_server::SocketRequest *request, + ::posix_server::SocketResponse *response) override { + response->set_fd( + socket(request->domain(), request->type(), request->protocol())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Recv(::grpc::ServerContext *context, + const ::posix_server::RecvRequest *request, + ::posix_server::RecvResponse *response) override { + std::vector<char> buf(request->len()); + response->set_ret( + recv(request->sockfd(), buf.data(), buf.size(), request->flags())); response->set_errno_(errno); + response->set_buf(buf.data(), response->ret()); return ::grpc::Status::OK; } }; diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto index 026876fc2..4035e1ee6 100644 --- a/test/packetimpact/proto/posix_server.proto +++ b/test/packetimpact/proto/posix_server.proto @@ -16,17 +16,6 @@ syntax = "proto3"; package posix_server; -message SocketRequest { - int32 domain = 1; - int32 type = 2; - int32 protocol = 3; -} - -message SocketResponse { - int32 fd = 1; - int32 errno_ = 2; -} - message SockaddrIn { int32 family = 1; uint32 port = 2; @@ -48,6 +37,23 @@ message Sockaddr { } } +message Timeval { + int64 seconds = 1; + int64 microseconds = 2; +} + +// Request and Response pairs for each Posix service RPC call, sorted. + +message AcceptRequest { + int32 sockfd = 1; +} + +message AcceptResponse { + int32 fd = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. + Sockaddr addr = 3; +} + message BindRequest { int32 sockfd = 1; Sockaddr addr = 2; @@ -55,7 +61,16 @@ message BindRequest { message BindResponse { int32 ret = 1; - int32 errno_ = 2; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message CloseRequest { + int32 fd = 1; +} + +message CloseResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. } message GetSockNameRequest { @@ -64,7 +79,7 @@ message GetSockNameRequest { message GetSockNameResponse { int32 ret = 1; - int32 errno_ = 2; + int32 errno_ = 2; // "errno" may fail to compile in c++. Sockaddr addr = 3; } @@ -75,17 +90,18 @@ message ListenRequest { message ListenResponse { int32 ret = 1; - int32 errno_ = 2; + int32 errno_ = 2; // "errno" may fail to compile in c++. } -message AcceptRequest { +message SendRequest { int32 sockfd = 1; + bytes buf = 2; + int32 flags = 3; } -message AcceptResponse { - int32 fd = 1; +message SendResponse { + int32 ret = 1; int32 errno_ = 2; - Sockaddr addr = 3; } message SetSockOptRequest { @@ -97,12 +113,19 @@ message SetSockOptRequest { message SetSockOptResponse { int32 ret = 1; - int32 errno_ = 2; + int32 errno_ = 2; // "errno" may fail to compile in c++. } -message Timeval { - int64 seconds = 1; - int64 microseconds = 2; +message SetSockOptIntRequest { + int32 sockfd = 1; + int32 level = 2; + int32 optname = 3; + int32 intval = 4; +} + +message SetSockOptIntResponse { + int32 ret = 1; + int32 errno_ = 2; } message SetSockOptTimevalRequest { @@ -114,37 +137,57 @@ message SetSockOptTimevalRequest { message SetSockOptTimevalResponse { int32 ret = 1; - int32 errno_ = 2; + int32 errno_ = 2; // "errno" may fail to compile in c++. } -message CloseRequest { +message SocketRequest { + int32 domain = 1; + int32 type = 2; + int32 protocol = 3; +} + +message SocketResponse { int32 fd = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. } -message CloseResponse { +message RecvRequest { + int32 sockfd = 1; + int32 len = 2; + int32 flags = 3; +} + +message RecvResponse { int32 ret = 1; - int32 errno_ = 2; + int32 errno_ = 2; // "errno" may fail to compile in c++. + bytes buf = 3; } service Posix { - // Call socket() on the DUT. - rpc Socket(SocketRequest) returns (SocketResponse); + // Call accept() on the DUT. + rpc Accept(AcceptRequest) returns (AcceptResponse); // Call bind() on the DUT. rpc Bind(BindRequest) returns (BindResponse); + // Call close() on the DUT. + rpc Close(CloseRequest) returns (CloseResponse); // Call getsockname() on the DUT. rpc GetSockName(GetSockNameRequest) returns (GetSockNameResponse); // Call listen() on the DUT. rpc Listen(ListenRequest) returns (ListenResponse); - // Call accept() on the DUT. - rpc Accept(AcceptRequest) returns (AcceptResponse); + // Call send() on the DUT. + rpc Send(SendRequest) returns (SendResponse); // Call setsockopt() on the DUT. You should prefer one of the other // SetSockOpt* functions with a more structured optval or else you may get the // encoding wrong, such as making a bad assumption about the server's word // sizes or endianness. rpc SetSockOpt(SetSockOptRequest) returns (SetSockOptResponse); + // Call setsockopt() on the DUT with an int optval. + rpc SetSockOptInt(SetSockOptIntRequest) returns (SetSockOptIntResponse); // Call setsockopt() on the DUT with a Timeval optval. rpc SetSockOptTimeval(SetSockOptTimevalRequest) returns (SetSockOptTimevalResponse); - // Call close() on the DUT. - rpc Close(CloseRequest) returns (CloseResponse); + // Call socket() on the DUT. + rpc Socket(SocketRequest) returns (SocketResponse); + // Call recv() on the DUT. + rpc Recv(RecvRequest) returns (RecvResponse); } diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD index a34c81fcc..b6a254882 100644 --- a/test/packetimpact/testbench/BUILD +++ b/test/packetimpact/testbench/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package( default_visibility = ["//test/packetimpact:__subpackages__"], @@ -16,6 +16,7 @@ go_library( ], deps = [ "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/seqnum", "//pkg/usermem", @@ -27,5 +28,14 @@ go_library( "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//keepalive:go_default_library", "@org_golang_x_sys//unix:go_default_library", + "@org_uber_go_multierr//:go_default_library", ], ) + +go_test( + name = "testbench_test", + size = "small", + srcs = ["layers_test.go"], + library = ":testbench", + deps = ["//pkg/tcpip"], +) diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index b7aa63934..f84fd8ba7 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -21,10 +21,12 @@ import ( "fmt" "math/rand" "net" + "strings" "testing" "time" "github.com/mohae/deepcopy" + "go.uber.org/multierr" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -36,19 +38,6 @@ var remoteIPv4 = flag.String("remote_ipv4", "", "remote IPv4 address for test pa var localMAC = flag.String("local_mac", "", "local mac address for test packets") var remoteMAC = flag.String("remote_mac", "", "remote mac address for test packets") -// TCPIPv4 maintains state about a TCP/IPv4 connection. -type TCPIPv4 struct { - outgoing Layers - incoming Layers - LocalSeqNum seqnum.Value - RemoteSeqNum seqnum.Value - SynAck *TCP - sniffer Sniffer - injector Injector - portPickerFD int - t *testing.T -} - // pickPort makes a new socket and returns the socket FD and port. The caller // must close the FD when done with the port if there is no error. func pickPort() (int, uint16, error) { @@ -75,171 +64,607 @@ func pickPort() (int, uint16, error) { return fd, uint16(newSockAddrInet4.Port), nil } -// tcpLayerIndex is the position of the TCP layer in the TCPIPv4 connection. It -// is the third, after Ethernet and IPv4. -const tcpLayerIndex int = 2 +// layerState stores the state of a layer of a connection. +type layerState interface { + // outgoing returns an outgoing layer to be sent in a frame. + outgoing() Layer -// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults. -func NewTCPIPv4(t *testing.T, dut DUT, outgoingTCP, incomingTCP TCP) TCPIPv4 { + // incoming creates an expected Layer for comparing against a received Layer. + // Because the expectation can depend on values in the received Layer, it is + // an input to incoming. For example, the ACK number needs to be checked in a + // TCP packet but only if the ACK flag is set in the received packet. + incoming(received Layer) Layer + + // sent updates the layerState based on the Layer that was sent. The input is + // a Layer with all prev and next pointers populated so that the entire frame + // as it was sent is available. + sent(sent Layer) error + + // received updates the layerState based on a Layer that is receieved. The + // input is a Layer with all prev and next pointers populated so that the + // entire frame as it was receieved is available. + received(received Layer) error + + // close frees associated resources held by the LayerState. + close() error +} + +// etherState maintains state about an Ethernet connection. +type etherState struct { + out, in Ether +} + +var _ layerState = (*etherState)(nil) + +// newEtherState creates a new etherState. +func newEtherState(out, in Ether) (*etherState, error) { lMAC, err := tcpip.ParseMACAddress(*localMAC) if err != nil { - t.Fatalf("can't parse localMAC %q: %s", *localMAC, err) + return nil, err } rMAC, err := tcpip.ParseMACAddress(*remoteMAC) if err != nil { - t.Fatalf("can't parse remoteMAC %q: %s", *remoteMAC, err) + return nil, err } - - portPickerFD, localPort, err := pickPort() - if err != nil { - t.Fatalf("can't pick a port: %s", err) + s := etherState{ + out: Ether{SrcAddr: &lMAC, DstAddr: &rMAC}, + in: Ether{SrcAddr: &rMAC, DstAddr: &lMAC}, + } + if err := s.out.merge(&out); err != nil { + return nil, err } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +func (s *etherState) outgoing() Layer { + return &s.out +} + +func (s *etherState) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*etherState) sent(Layer) error { + return nil +} + +func (*etherState) received(Layer) error { + return nil +} + +func (*etherState) close() error { + return nil +} + +// ipv4State maintains state about an IPv4 connection. +type ipv4State struct { + out, in IPv4 +} + +var _ layerState = (*ipv4State)(nil) + +// newIPv4State creates a new ipv4State. +func newIPv4State(out, in IPv4) (*ipv4State, error) { lIP := tcpip.Address(net.ParseIP(*localIPv4).To4()) rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4()) + s := ipv4State{ + out: IPv4{SrcAddr: &lIP, DstAddr: &rIP}, + in: IPv4{SrcAddr: &rIP, DstAddr: &lIP}, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} - sniffer, err := NewSniffer(t) +func (s *ipv4State) outgoing() Layer { + return &s.out +} + +func (s *ipv4State) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*ipv4State) sent(Layer) error { + return nil +} + +func (*ipv4State) received(Layer) error { + return nil +} + +func (*ipv4State) close() error { + return nil +} + +// tcpState maintains state about a TCP connection. +type tcpState struct { + out, in TCP + localSeqNum, remoteSeqNum *seqnum.Value + synAck *TCP + portPickerFD int + finSent bool +} + +var _ layerState = (*tcpState)(nil) + +// SeqNumValue is a helper routine that allocates a new seqnum.Value value to +// store v and returns a pointer to it. +func SeqNumValue(v seqnum.Value) *seqnum.Value { + return &v +} + +// newTCPState creates a new TCPState. +func newTCPState(out, in TCP) (*tcpState, error) { + portPickerFD, localPort, err := pickPort() if err != nil { - t.Fatalf("can't make new sniffer: %s", err) + return nil, err + } + s := tcpState{ + out: TCP{SrcPort: &localPort}, + in: TCP{DstPort: &localPort}, + localSeqNum: SeqNumValue(seqnum.Value(rand.Uint32())), + portPickerFD: portPickerFD, + finSent: false, + } + if err := s.out.merge(&out); err != nil { + return nil, err } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} - injector, err := NewInjector(t) - if err != nil { - t.Fatalf("can't make new injector: %s", err) +func (s *tcpState) outgoing() Layer { + newOutgoing := deepcopy.Copy(s.out).(TCP) + if s.localSeqNum != nil { + newOutgoing.SeqNum = Uint32(uint32(*s.localSeqNum)) + } + if s.remoteSeqNum != nil { + newOutgoing.AckNum = Uint32(uint32(*s.remoteSeqNum)) + } + return &newOutgoing +} + +func (s *tcpState) incoming(received Layer) Layer { + tcpReceived, ok := received.(*TCP) + if !ok { + return nil + } + newIn := deepcopy.Copy(s.in).(TCP) + if s.remoteSeqNum != nil { + newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum)) + } + if s.localSeqNum != nil && (*tcpReceived.Flags&header.TCPFlagAck) != 0 { + // The caller didn't specify an AckNum so we'll expect the calculated one, + // but only if the ACK flag is set because the AckNum is not valid in a + // header if ACK is not set. + newIn.AckNum = Uint32(uint32(*s.localSeqNum)) } + return &newIn +} - newOutgoingTCP := &TCP{ - DataOffset: Uint8(header.TCPMinimumSize), - WindowSize: Uint16(32768), - SrcPort: &localPort, +func (s *tcpState) sent(sent Layer) error { + tcp, ok := sent.(*TCP) + if !ok { + return fmt.Errorf("can't update tcpState with %T Layer", sent) + } + if !s.finSent { + // update localSeqNum by the payload only when FIN is not yet sent by us + for current := tcp.next(); current != nil; current = current.next() { + s.localSeqNum.UpdateForward(seqnum.Size(current.length())) + } } - if err := newOutgoingTCP.merge(outgoingTCP); err != nil { - t.Fatalf("can't merge %v into %v: %s", outgoingTCP, newOutgoingTCP, err) + if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { + s.localSeqNum.UpdateForward(1) } - newIncomingTCP := &TCP{ - DstPort: &localPort, + if *tcp.Flags&(header.TCPFlagFin) != 0 { + s.finSent = true } - if err := newIncomingTCP.merge(incomingTCP); err != nil { - t.Fatalf("can't merge %v into %v: %s", incomingTCP, newIncomingTCP, err) + return nil +} + +func (s *tcpState) received(l Layer) error { + tcp, ok := l.(*TCP) + if !ok { + return fmt.Errorf("can't update tcpState with %T Layer", l) } - return TCPIPv4{ - outgoing: Layers{ - &Ether{SrcAddr: &lMAC, DstAddr: &rMAC}, - &IPv4{SrcAddr: &lIP, DstAddr: &rIP}, - newOutgoingTCP}, - incoming: Layers{ - &Ether{SrcAddr: &rMAC, DstAddr: &lMAC}, - &IPv4{SrcAddr: &rIP, DstAddr: &lIP}, - newIncomingTCP}, - sniffer: sniffer, - injector: injector, + s.remoteSeqNum = SeqNumValue(seqnum.Value(*tcp.SeqNum)) + if *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { + s.remoteSeqNum.UpdateForward(1) + } + for current := tcp.next(); current != nil; current = current.next() { + s.remoteSeqNum.UpdateForward(seqnum.Size(current.length())) + } + return nil +} + +// close frees the port associated with this connection. +func (s *tcpState) close() error { + if err := unix.Close(s.portPickerFD); err != nil { + return err + } + s.portPickerFD = -1 + return nil +} + +// udpState maintains state about a UDP connection. +type udpState struct { + out, in UDP + portPickerFD int +} + +var _ layerState = (*udpState)(nil) + +// newUDPState creates a new udpState. +func newUDPState(out, in UDP) (*udpState, error) { + portPickerFD, localPort, err := pickPort() + if err != nil { + return nil, err + } + s := udpState{ + out: UDP{SrcPort: &localPort}, + in: UDP{DstPort: &localPort}, portPickerFD: portPickerFD, - t: t, - LocalSeqNum: seqnum.Value(rand.Uint32()), } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil } -// Close the injector and sniffer associated with this connection. -func (conn *TCPIPv4) Close() { - conn.sniffer.Close() - conn.injector.Close() - if err := unix.Close(conn.portPickerFD); err != nil { - conn.t.Fatalf("can't close portPickerFD: %s", err) +func (s *udpState) outgoing() Layer { + return &s.out +} + +func (s *udpState) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*udpState) sent(l Layer) error { + return nil +} + +func (*udpState) received(l Layer) error { + return nil +} + +// close frees the port associated with this connection. +func (s *udpState) close() error { + if err := unix.Close(s.portPickerFD); err != nil { + return err } - conn.portPickerFD = -1 + s.portPickerFD = -1 + return nil } -// Send a packet with reasonable defaults and override some fields by tcp. -func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { - if tcp.SeqNum == nil { - tcp.SeqNum = Uint32(uint32(conn.LocalSeqNum)) +// Connection holds a collection of layer states for maintaining a connection +// along with sockets for sniffer and injecting packets. +type Connection struct { + layerStates []layerState + injector Injector + sniffer Sniffer + t *testing.T +} + +// match tries to match each Layer in received against the incoming filter. If +// received is longer than layerStates then that may still count as a match. The +// reverse is never a match. override overrides the default matchers for each +// Layer. +func (conn *Connection) match(override, received Layers) bool { + if len(received) < len(conn.layerStates) { + return false + } + for i, s := range conn.layerStates { + toMatch := s.incoming(received[i]) + if toMatch == nil { + return false + } + if i < len(override) { + toMatch.merge(override[i]) + } + if !toMatch.match(received[i]) { + return false + } + } + return true +} + +// Close frees associated resources held by the Connection. +func (conn *Connection) Close() { + errs := multierr.Combine(conn.sniffer.close(), conn.injector.close()) + for _, s := range conn.layerStates { + if err := s.close(); err != nil { + errs = multierr.Append(errs, fmt.Errorf("unable to close %+v: %s", s, err)) + } } - if tcp.AckNum == nil { - tcp.AckNum = Uint32(uint32(conn.RemoteSeqNum)) + if errs != nil { + conn.t.Fatalf("unable to close %+v: %s", conn, errs) } - layersToSend := deepcopy.Copy(conn.outgoing).(Layers) - if err := layersToSend[tcpLayerIndex].(*TCP).merge(tcp); err != nil { - conn.t.Fatalf("can't merge %v into %v: %s", tcp, layersToSend[tcpLayerIndex], err) +} + +// CreateFrame builds a frame for the connection with layer overriding defaults +// of the innermost layer and additionalLayers added after it. +func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Layers { + var layersToSend Layers + for _, s := range conn.layerStates { + layersToSend = append(layersToSend, s.outgoing()) + } + if err := layersToSend[len(layersToSend)-1].merge(layer); err != nil { + conn.t.Fatalf("can't merge %+v into %+v: %s", layer, layersToSend[len(layersToSend)-1], err) } layersToSend = append(layersToSend, additionalLayers...) - outBytes, err := layersToSend.toBytes() + return layersToSend +} + +// SendFrame sends a frame on the wire and updates the state of all layers. +func (conn *Connection) SendFrame(frame Layers) { + outBytes, err := frame.toBytes() if err != nil { conn.t.Fatalf("can't build outgoing TCP packet: %s", err) } conn.injector.Send(outBytes) - // Compute the next TCP sequence number. - for i := tcpLayerIndex + 1; i < len(layersToSend); i++ { - conn.LocalSeqNum.UpdateForward(seqnum.Size(layersToSend[i].length())) + // frame might have nil values where the caller wanted to use default values. + // sentFrame will have no nil values in it because it comes from parsing the + // bytes that were actually sent. + sentFrame := parse(parseEther, outBytes) + // Update the state of each layer based on what was sent. + for i, s := range conn.layerStates { + if err := s.sent(sentFrame[i]); err != nil { + conn.t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err) + } } - if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { - conn.LocalSeqNum.UpdateForward(1) +} + +// Send a packet with reasonable defaults. Potentially override the final layer +// in the connection with the provided layer and add additionLayers. +func (conn *Connection) Send(layer Layer, additionalLayers ...Layer) { + conn.SendFrame(conn.CreateFrame(layer, additionalLayers...)) +} + +// recvFrame gets the next successfully parsed frame (of type Layers) within the +// timeout provided. If no parsable frame arrives before the timeout, it returns +// nil. +func (conn *Connection) recvFrame(timeout time.Duration) Layers { + if timeout <= 0 { + return nil } + b := conn.sniffer.Recv(timeout) + if b == nil { + return nil + } + return parse(parseEther, b) } -// Recv gets a packet from the sniffer within the timeout provided. If no packet -// arrives before the timeout, it returns nil. -func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP { - deadline := time.Now().Add(timeout) - for { - timeout = deadline.Sub(time.Now()) - if timeout <= 0 { - break - } - b := conn.sniffer.Recv(timeout) - if b == nil { - break - } - layers, err := ParseEther(b) - if err != nil { - continue // Ignore packets that can't be parsed. - } - if !conn.incoming.match(layers) { - continue // Ignore packets that don't match the expected incoming. - } - tcpHeader := (layers[tcpLayerIndex]).(*TCP) - conn.RemoteSeqNum = seqnum.Value(*tcpHeader.SeqNum) - if *tcpHeader.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { - conn.RemoteSeqNum.UpdateForward(1) - } - for i := tcpLayerIndex + 1; i < len(layers); i++ { - conn.RemoteSeqNum.UpdateForward(seqnum.Size(layers[i].length())) - } - return tcpHeader +// Expect a frame with the final layerStates layer matching the provided Layer +// within the timeout specified. If it doesn't arrive in time, it returns nil. +func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) { + // Make a frame that will ignore all but the final layer. + layers := make([]Layer, len(conn.layerStates)) + layers[len(layers)-1] = layer + + gotFrame, err := conn.ExpectFrame(layers, timeout) + if err != nil { + return nil, err } - return nil + if len(conn.layerStates)-1 < len(gotFrame) { + return gotFrame[len(conn.layerStates)-1], nil + } + conn.t.Fatal("the received frame should be at least as long as the expected layers") + return nil, fmt.Errorf("the received frame should be at least as long as the expected layers") } -// Expect a packet that matches the provided tcp within the timeout specified. -// If it doesn't arrive in time, the test fails. -func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) *TCP { +// ExpectFrame expects a frame that matches the provided Layers within the +// timeout specified. If it doesn't arrive in time, it returns nil. +func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) { deadline := time.Now().Add(timeout) + var allLayers []string for { - timeout = deadline.Sub(time.Now()) - if timeout <= 0 { - return nil + var gotLayers Layers + if timeout = time.Until(deadline); timeout > 0 { + gotLayers = conn.recvFrame(timeout) } - gotTCP := conn.Recv(timeout) - if gotTCP == nil { - return nil + if gotLayers == nil { + return nil, fmt.Errorf("got %d packets:\n%s", len(allLayers), strings.Join(allLayers, "\n")) } - if tcp.match(gotTCP) { - return gotTCP + if conn.match(layers, gotLayers) { + for i, s := range conn.layerStates { + if err := s.received(gotLayers[i]); err != nil { + conn.t.Fatal(err) + } + } + return gotLayers, nil } + allLayers = append(allLayers, fmt.Sprintf("%s", gotLayers)) + } +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *Connection) Drain() { + conn.sniffer.Drain() +} + +// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection. +type TCPIPv4 Connection + +// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults. +func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv4State, err := newIPv4State(IPv4{}, IPv4{}) + if err != nil { + t.Fatalf("can't make ipv4State: %s", err) + } + tcpState, err := newTCPState(outgoingTCP, incomingTCP) + if err != nil { + t.Fatalf("can't make tcpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return TCPIPv4{ + layerStates: []layerState{etherState, ipv4State, tcpState}, + injector: injector, + sniffer: sniffer, + t: t, } } -// Handshake performs a TCP 3-way handshake. +// Handshake performs a TCP 3-way handshake. The input Connection should have a +// final TCP Layer. func (conn *TCPIPv4) Handshake() { // Send the SYN. conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)}) // Wait for the SYN-ACK. - conn.SynAck = conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) - if conn.SynAck == nil { - conn.t.Fatalf("didn't get synack during handshake") + synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + if synAck == nil { + conn.t.Fatalf("didn't get synack during handshake: %s", err) } + conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck // Send an ACK. conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) } + +// ExpectData is a convenient method that expects a Layer and the Layer after +// it. If it doens't arrive in time, it returns nil. +func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = tcp + if payload != nil { + expected = append(expected, payload) + } + return (*Connection)(conn).ExpectFrame(expected, timeout) +} + +// Send a packet with reasonable defaults. Potentially override the TCP layer in +// the connection with the provided layer and add additionLayers. +func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { + (*Connection)(conn).Send(&tcp, additionalLayers...) +} + +// Close frees associated resources held by the TCPIPv4 connection. +func (conn *TCPIPv4) Close() { + (*Connection)(conn).Close() +} + +// Expect a frame with the TCP layer matching the provided TCP within the +// timeout specified. If it doesn't arrive in time, it returns nil. +func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) { + layer, err := (*Connection)(conn).Expect(&tcp, timeout) + if layer == nil { + return nil, err + } + gotTCP, ok := layer.(*TCP) + if !ok { + conn.t.Fatalf("expected %s to be TCP", layer) + } + return gotTCP, err +} + +func (conn *TCPIPv4) state() *tcpState { + state, ok := conn.layerStates[len(conn.layerStates)-1].(*tcpState) + if !ok { + conn.t.Fatalf("expected final state of %v to be tcpState", conn.layerStates) + } + return state +} + +// RemoteSeqNum returns the next expected sequence number from the DUT. +func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value { + return conn.state().remoteSeqNum +} + +// LocalSeqNum returns the next sequence number to send from the testbench. +func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value { + return conn.state().localSeqNum +} + +// SynAck returns the SynAck that was part of the handshake. +func (conn *TCPIPv4) SynAck() *TCP { + return conn.state().synAck +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *TCPIPv4) Drain() { + conn.sniffer.Drain() +} + +// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection. +type UDPIPv4 Connection + +// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults. +func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 { + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv4State, err := newIPv4State(IPv4{}, IPv4{}) + if err != nil { + t.Fatalf("can't make ipv4State: %s", err) + } + tcpState, err := newUDPState(outgoingUDP, incomingUDP) + if err != nil { + t.Fatalf("can't make udpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return UDPIPv4{ + layerStates: []layerState{etherState, ipv4State, tcpState}, + injector: injector, + sniffer: sniffer, + t: t, + } +} + +// CreateFrame builds a frame for the connection with layer overriding defaults +// of the innermost layer and additionalLayers added after it. +func (conn *UDPIPv4) CreateFrame(layer Layer, additionalLayers ...Layer) Layers { + return (*Connection)(conn).CreateFrame(layer, additionalLayers...) +} + +// SendFrame sends a frame on the wire and updates the state of all layers. +func (conn *UDPIPv4) SendFrame(frame Layers) { + (*Connection)(conn).SendFrame(frame) +} + +// Close frees associated resources held by the UDPIPv4 connection. +func (conn *UDPIPv4) Close() { + (*Connection)(conn).Close() +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *UDPIPv4) Drain() { + conn.sniffer.Drain() +} diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index 8ea1706d3..9335909c0 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -65,33 +65,6 @@ func (dut *DUT) TearDown() { dut.conn.Close() } -// SocketWithErrno calls socket on the DUT and returns the fd and errno. -func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) { - dut.t.Helper() - req := pb.SocketRequest{ - Domain: domain, - Type: typ, - Protocol: proto, - } - ctx := context.Background() - resp, err := dut.posixServer.Socket(ctx, &req) - if err != nil { - dut.t.Fatalf("failed to call Socket: %s", err) - } - return resp.GetFd(), syscall.Errno(resp.GetErrno_()) -} - -// Socket calls socket on the DUT and returns the file descriptor. If socket -// fails on the DUT, the test ends. -func (dut *DUT) Socket(domain, typ, proto int32) int32 { - dut.t.Helper() - fd, err := dut.SocketWithErrno(domain, typ, proto) - if fd < 0 { - dut.t.Fatalf("failed to create socket: %s", err) - } - return fd -} - func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr { dut.t.Helper() switch s := sa.(type) { @@ -142,14 +115,95 @@ func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr { return nil } +// CreateBoundSocket makes a new socket on the DUT, with type typ and protocol +// proto, and bound to the IP address addr. Returns the new file descriptor and +// the port that was selected on the DUT. +func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) { + dut.t.Helper() + var fd int32 + if addr.To4() != nil { + fd = dut.Socket(unix.AF_INET, typ, proto) + sa := unix.SockaddrInet4{} + copy(sa.Addr[:], addr.To4()) + dut.Bind(fd, &sa) + } else if addr.To16() != nil { + fd = dut.Socket(unix.AF_INET6, typ, proto) + sa := unix.SockaddrInet6{} + copy(sa.Addr[:], addr.To16()) + dut.Bind(fd, &sa) + } else { + dut.t.Fatal("unknown ip addr type for remoteIP") + } + sa := dut.GetSockName(fd) + var port int + switch s := sa.(type) { + case *unix.SockaddrInet4: + port = s.Port + case *unix.SockaddrInet6: + port = s.Port + default: + dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa) + } + return fd, uint16(port) +} + +// CreateListener makes a new TCP connection. If it fails, the test ends. +func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) { + fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(*remoteIPv4)) + dut.Listen(fd, backlog) + return fd, remotePort +} + +// All the functions that make gRPC calls to the Posix service are below, sorted +// alphabetically. + +// Accept calls accept on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// AcceptWithErrno. +func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + fd, sa, err := dut.AcceptWithErrno(ctx, sockfd) + if fd < 0 { + dut.t.Fatalf("failed to accept: %s", err) + } + return fd, sa +} + +// AcceptWithErrno calls accept on the DUT. +func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) { + dut.t.Helper() + req := pb.AcceptRequest{ + Sockfd: sockfd, + } + resp, err := dut.posixServer.Accept(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Accept: %s", err) + } + return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_()) +} + +// Bind calls bind on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is +// needed, use BindWithErrno. +func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, err := dut.BindWithErrno(ctx, fd, sa) + if ret != 0 { + dut.t.Fatalf("failed to bind socket: %s", err) + } +} + // BindWithErrno calls bind on the DUT. -func (dut *DUT) BindWithErrno(fd int32, sa unix.Sockaddr) (int32, error) { +func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) { dut.t.Helper() req := pb.BindRequest{ Sockfd: fd, Addr: dut.sockaddrToProto(sa), } - ctx := context.Background() resp, err := dut.posixServer.Bind(ctx, &req) if err != nil { dut.t.Fatalf("failed to call Bind: %s", err) @@ -157,23 +211,52 @@ func (dut *DUT) BindWithErrno(fd int32, sa unix.Sockaddr) (int32, error) { return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } -// Bind calls bind on the DUT and causes a fatal test failure if it doesn't -// succeed. -func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) { +// Close calls close on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// CloseWithErrno. +func (dut *DUT) Close(fd int32) { dut.t.Helper() - ret, err := dut.BindWithErrno(fd, sa) + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, err := dut.CloseWithErrno(ctx, fd) if ret != 0 { - dut.t.Fatalf("failed to bind socket: %s", err) + dut.t.Fatalf("failed to close: %s", err) } } +// CloseWithErrno calls close on the DUT. +func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) { + dut.t.Helper() + req := pb.CloseRequest{ + Fd: fd, + } + resp, err := dut.posixServer.Close(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Close: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// GetSockName calls getsockname on the DUT and causes a fatal test failure if +// it doesn't succeed. If more control over the timeout or error handling is +// needed, use GetSockNameWithErrno. +func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, sa, err := dut.GetSockNameWithErrno(ctx, sockfd) + if ret != 0 { + dut.t.Fatalf("failed to getsockname: %s", err) + } + return sa +} + // GetSockNameWithErrno calls getsockname on the DUT. -func (dut *DUT) GetSockNameWithErrno(sockfd int32) (int32, unix.Sockaddr, error) { +func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) { dut.t.Helper() req := pb.GetSockNameRequest{ Sockfd: sockfd, } - ctx := context.Background() resp, err := dut.posixServer.GetSockName(ctx, &req) if err != nil { dut.t.Fatalf("failed to call Bind: %s", err) @@ -181,26 +264,26 @@ func (dut *DUT) GetSockNameWithErrno(sockfd int32) (int32, unix.Sockaddr, error) return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_()) } -// GetSockName calls getsockname on the DUT and causes a fatal test failure if -// it doens't succeed. -func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr { +// Listen calls listen on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// ListenWithErrno. +func (dut *DUT) Listen(sockfd, backlog int32) { dut.t.Helper() - ret, sa, err := dut.GetSockNameWithErrno(sockfd) + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, err := dut.ListenWithErrno(ctx, sockfd, backlog) if ret != 0 { - dut.t.Fatalf("failed to getsockname: %s", err) + dut.t.Fatalf("failed to listen: %s", err) } - return sa } // ListenWithErrno calls listen on the DUT. -func (dut *DUT) ListenWithErrno(sockfd, backlog int32) (int32, error) { +func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int32, error) { dut.t.Helper() req := pb.ListenRequest{ Sockfd: sockfd, Backlog: backlog, } - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) - defer cancel() resp, err := dut.posixServer.Listen(ctx, &req) if err != nil { dut.t.Fatalf("failed to call Listen: %s", err) @@ -208,44 +291,54 @@ func (dut *DUT) ListenWithErrno(sockfd, backlog int32) (int32, error) { return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } -// Listen calls listen on the DUT and causes a fatal test failure if it doesn't -// succeed. -func (dut *DUT) Listen(sockfd, backlog int32) { +// Send calls send on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// SendWithErrno. +func (dut *DUT) Send(sockfd int32, buf []byte, flags int32) int32 { dut.t.Helper() - ret, err := dut.ListenWithErrno(sockfd, backlog) - if ret != 0 { - dut.t.Fatalf("failed to listen: %s", err) + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, err := dut.SendWithErrno(ctx, sockfd, buf, flags) + if ret == -1 { + dut.t.Fatalf("failed to send: %s", err) } + return ret } -// AcceptWithErrno calls accept on the DUT. -func (dut *DUT) AcceptWithErrno(sockfd int32) (int32, unix.Sockaddr, error) { +// SendWithErrno calls send on the DUT. +func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32) (int32, error) { dut.t.Helper() - req := pb.AcceptRequest{ + req := pb.SendRequest{ Sockfd: sockfd, + Buf: buf, + Flags: flags, } - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) - defer cancel() - resp, err := dut.posixServer.Accept(ctx, &req) + resp, err := dut.posixServer.Send(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Accept: %s", err) + dut.t.Fatalf("failed to call Send: %s", err) } - return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_()) + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } -// Accept calls accept on the DUT and causes a fatal test failure if it doesn't -// succeed. -func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) { +// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it +// doesn't succeed. If more control over the timeout or error handling is +// needed, use SetSockOptWithErrno. Because endianess and the width of values +// might differ between the testbench and DUT architectures, prefer to use a +// more specific SetSockOptXxx function. +func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) { dut.t.Helper() - fd, sa, err := dut.AcceptWithErrno(sockfd) - if fd < 0 { - dut.t.Fatalf("failed to accept: %s", err) + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval) + if ret != 0 { + dut.t.Fatalf("failed to SetSockOpt: %s", err) } - return fd, sa } -// SetSockOptWithErrno calls setsockopt on the DUT. -func (dut *DUT) SetSockOptWithErrno(sockfd, level, optname int32, optval []byte) (int32, error) { +// SetSockOptWithErrno calls setsockopt on the DUT. Because endianess and the +// width of values might differ between the testbench and DUT architectures, +// prefer to use a more specific SetSockOptXxxWithErrno function. +func (dut *DUT) SetSockOptWithErrno(ctx context.Context, sockfd, level, optname int32, optval []byte) (int32, error) { dut.t.Helper() req := pb.SetSockOptRequest{ Sockfd: sockfd, @@ -253,8 +346,6 @@ func (dut *DUT) SetSockOptWithErrno(sockfd, level, optname int32, optval []byte) Optname: optname, Optval: optval, } - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) - defer cancel() resp, err := dut.posixServer.SetSockOpt(ctx, &req) if err != nil { dut.t.Fatalf("failed to call SetSockOpt: %s", err) @@ -262,19 +353,51 @@ func (dut *DUT) SetSockOptWithErrno(sockfd, level, optname int32, optval []byte) return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } -// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it -// doesn't succeed. -func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) { +// SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure +// if it doesn't succeed. If more control over the int optval or error handling +// is needed, use SetSockOptIntWithErrno. +func (dut *DUT) SetSockOptInt(sockfd, level, optname, optval int32) { dut.t.Helper() - ret, err := dut.SetSockOptWithErrno(sockfd, level, optname, optval) + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, err := dut.SetSockOptIntWithErrno(ctx, sockfd, level, optname, optval) if ret != 0 { - dut.t.Fatalf("failed to SetSockOpt: %s", err) + dut.t.Fatalf("failed to SetSockOptInt: %s", err) + } +} + +// SetSockOptIntWithErrno calls setsockopt with an integer optval. +func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname, optval int32) (int32, error) { + dut.t.Helper() + req := pb.SetSockOptIntRequest{ + Sockfd: sockfd, + Level: level, + Optname: optname, + Intval: optval, + } + resp, err := dut.posixServer.SetSockOptInt(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call SetSockOptInt: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure +// if it doesn't succeed. If more control over the timeout or error handling is +// needed, use SetSockOptTimevalWithErrno. +func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, err := dut.SetSockOptTimevalWithErrno(ctx, sockfd, level, optname, tv) + if ret != 0 { + dut.t.Fatalf("failed to SetSockOptTimeval: %s", err) } } // SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to // bytes. -func (dut *DUT) SetSockOptTimevalWithErrno(sockfd, level, optname int32, tv *unix.Timeval) (int32, error) { +func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) { dut.t.Helper() timeval := pb.Timeval{ Seconds: int64(tv.Sec), @@ -286,8 +409,6 @@ func (dut *DUT) SetSockOptTimevalWithErrno(sockfd, level, optname int32, tv *uni Optname: optname, Timeval: &timeval, } - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) - defer cancel() resp, err := dut.posixServer.SetSockOptTimeval(ctx, &req) if err != nil { dut.t.Fatalf("failed to call SetSockOptTimeval: %s", err) @@ -295,69 +416,58 @@ func (dut *DUT) SetSockOptTimevalWithErrno(sockfd, level, optname int32, tv *uni return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } -// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure -// if it doesn't succeed. -func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) { +// Socket calls socket on the DUT and returns the file descriptor. If socket +// fails on the DUT, the test ends. +func (dut *DUT) Socket(domain, typ, proto int32) int32 { dut.t.Helper() - ret, err := dut.SetSockOptTimevalWithErrno(sockfd, level, optname, tv) - if ret != 0 { - dut.t.Fatalf("failed to SetSockOptTimeval: %s", err) + fd, err := dut.SocketWithErrno(domain, typ, proto) + if fd < 0 { + dut.t.Fatalf("failed to create socket: %s", err) } + return fd } -// CloseWithErrno calls close on the DUT. -func (dut *DUT) CloseWithErrno(fd int32) (int32, error) { +// SocketWithErrno calls socket on the DUT and returns the fd and errno. +func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) { dut.t.Helper() - req := pb.CloseRequest{ - Fd: fd, + req := pb.SocketRequest{ + Domain: domain, + Type: typ, + Protocol: proto, } - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) - defer cancel() - resp, err := dut.posixServer.Close(ctx, &req) + ctx := context.Background() + resp, err := dut.posixServer.Socket(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Close: %s", err) + dut.t.Fatalf("failed to call Socket: %s", err) } - return resp.GetRet(), syscall.Errno(resp.GetErrno_()) + return resp.GetFd(), syscall.Errno(resp.GetErrno_()) } -// Close calls close on the DUT and causes a fatal test failure if it doesn't -// succeed. -func (dut *DUT) Close(fd int32) { +// Recv calls recv on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// RecvWithErrno. +func (dut *DUT) Recv(sockfd, len, flags int32) []byte { dut.t.Helper() - ret, err := dut.CloseWithErrno(fd) - if ret != 0 { - dut.t.Fatalf("failed to close: %s", err) + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags) + if ret == -1 { + dut.t.Fatalf("failed to recv: %s", err) } + return buf } -// CreateListener makes a new TCP connection. If it fails, the test ends. -func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) { +// RecvWithErrno calls recv on the DUT. +func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (int32, []byte, error) { dut.t.Helper() - addr := net.ParseIP(*remoteIPv4) - var fd int32 - if addr.To4() != nil { - fd = dut.Socket(unix.AF_INET, typ, proto) - sa := unix.SockaddrInet4{} - copy(sa.Addr[:], addr.To4()) - dut.Bind(fd, &sa) - } else if addr.To16() != nil { - fd = dut.Socket(unix.AF_INET6, typ, proto) - sa := unix.SockaddrInet6{} - copy(sa.Addr[:], addr.To16()) - dut.Bind(fd, &sa) - } else { - dut.t.Fatal("unknown ip addr type for remoteIP") + req := pb.RecvRequest{ + Sockfd: sockfd, + Len: len, + Flags: flags, } - sa := dut.GetSockName(fd) - var port int - switch s := sa.(type) { - case *unix.SockaddrInet4: - port = s.Port - case *unix.SockaddrInet6: - port = s.Port - default: - dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa) + resp, err := dut.posixServer.Recv(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Recv: %s", err) } - dut.Listen(fd, backlog) - return fd, uint16(port) + return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_()) } diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index 35fa4dcb6..5ce324f0d 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -15,13 +15,16 @@ package testbench import ( + "encoding/hex" "fmt" "reflect" + "strings" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/imdario/mergo" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -31,6 +34,8 @@ import ( // Layer contains all the fields of the encapsulation. Each field is a pointer // and may be nil. type Layer interface { + fmt.Stringer + // toBytes converts the Layer into bytes. In places where the Layer's field // isn't nil, the value that is pointed to is used. When the field is nil, a // reasonable default for the Layer is used. For example, "64" for IPv4 TTL @@ -42,7 +47,8 @@ type Layer interface { // match checks if the current Layer matches the provided Layer. If either // Layer has a nil in a given field, that field is considered matching. - // Otherwise, the values pointed to by the fields must match. + // Otherwise, the values pointed to by the fields must match. The LayerBase is + // ignored. match(Layer) bool // length in bytes of the current encapsulation @@ -59,6 +65,9 @@ type Layer interface { // setPrev sets the pointer to the Layer encapsulating this one. setPrev(Layer) + + // merge overrides the values in the interface with the provided values. + merge(Layer) error } // LayerBase is the common elements of all layers. @@ -83,21 +92,59 @@ func (lb *LayerBase) setPrev(l Layer) { lb.prevLayer = l } +// equalLayer compares that two Layer structs match while ignoring field in +// which either input has a nil and also ignoring the LayerBase of the inputs. func equalLayer(x, y Layer) bool { + if x == nil || y == nil { + return true + } + // opt ignores comparison pairs where either of the inputs is a nil. opt := cmp.FilterValues(func(x, y interface{}) bool { - if reflect.ValueOf(x).Kind() == reflect.Ptr && reflect.ValueOf(x).IsNil() { - return true - } - if reflect.ValueOf(y).Kind() == reflect.Ptr && reflect.ValueOf(y).IsNil() { - return true + for _, l := range []interface{}{x, y} { + v := reflect.ValueOf(l) + if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice) && v.IsNil() { + return true + } } return false - }, cmp.Ignore()) - return cmp.Equal(x, y, opt, cmpopts.IgnoreUnexported(LayerBase{})) + return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{})) +} + +// mergeLayer merges other in layer. Any non-nil value in other overrides the +// corresponding value in layer. If other is nil, no action is performed. +func mergeLayer(layer, other Layer) error { + if other == nil { + return nil + } + return mergo.Merge(layer, other, mergo.WithOverride) +} + +func stringLayer(l Layer) string { + v := reflect.ValueOf(l).Elem() + t := v.Type() + var ret []string + for i := 0; i < v.NumField(); i++ { + t := t.Field(i) + if t.Anonymous { + // Ignore the LayerBase in the Layer struct. + continue + } + v := v.Field(i) + if v.IsNil() { + continue + } + v = reflect.Indirect(v) + if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 { + ret = append(ret, fmt.Sprintf("%s:\n%v", t.Name, hex.Dump(v.Bytes()))) + } else { + ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v)) + } + } + return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " ")) } -// Ether can construct and match the ethernet encapsulation. +// Ether can construct and match an ethernet encapsulation. type Ether struct { LayerBase SrcAddr *tcpip.LinkAddress @@ -105,6 +152,10 @@ type Ether struct { Type *tcpip.NetworkProtocolNumber } +func (l *Ether) String() string { + return stringLayer(l) +} + func (l *Ether) toBytes() ([]byte, error) { b := make([]byte, header.EthernetMinimumSize) h := header.Ethernet(b) @@ -123,7 +174,7 @@ func (l *Ether) toBytes() ([]byte, error) { fields.Type = header.IPv4ProtocolNumber default: // TODO(b/150301488): Support more protocols, like IPv6. - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", n) + return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n) } } h.Encode(fields) @@ -142,27 +193,46 @@ func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocol return &v } -// ParseEther parses the bytes assuming that they start with an ethernet header +// layerParser parses the input bytes and returns a Layer along with the next +// layerParser to run. If there is no more parsing to do, the returned +// layerParser is nil. +type layerParser func([]byte) (Layer, layerParser) + +// parse parses bytes starting with the first layerParser and using successive +// layerParsers until all the bytes are parsed. +func parse(parser layerParser, b []byte) Layers { + var layers Layers + for { + var layer Layer + layer, parser = parser(b) + layers = append(layers, layer) + if parser == nil { + break + } + b = b[layer.length():] + } + layers.linkLayers() + return layers +} + +// parseEther parses the bytes assuming that they start with an ethernet header // and continues parsing further encapsulations. -func ParseEther(b []byte) (Layers, error) { +func parseEther(b []byte) (Layer, layerParser) { h := header.Ethernet(b) ether := Ether{ SrcAddr: LinkAddress(h.SourceAddress()), DstAddr: LinkAddress(h.DestinationAddress()), Type: NetworkProtocolNumber(h.Type()), } - layers := Layers{ðer} + var nextParser layerParser switch h.Type() { case header.IPv4ProtocolNumber: - moreLayers, err := ParseIPv4(b[ether.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + nextParser = parseIPv4 default: - // TODO(b/150301488): Support more protocols, like IPv6. - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %v", b) + // Assume that the rest is a payload. + nextParser = parsePayload } + return ðer, nextParser } func (l *Ether) match(other Layer) bool { @@ -173,7 +243,13 @@ func (l *Ether) length() int { return header.EthernetMinimumSize } -// IPv4 can construct and match the ethernet excapulation. +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *Ether) merge(other Layer) error { + return mergeLayer(l, other) +} + +// IPv4 can construct and match an IPv4 encapsulation. type IPv4 struct { LayerBase IHL *uint8 @@ -189,6 +265,10 @@ type IPv4 struct { DstAddr *tcpip.Address } +func (l *IPv4) String() string { + return stringLayer(l) +} + func (l *IPv4) toBytes() ([]byte, error) { b := make([]byte, header.IPv4MinimumSize) h := header.IPv4(b) @@ -236,9 +316,11 @@ func (l *IPv4) toBytes() ([]byte, error) { switch n := l.next().(type) { case *TCP: fields.Protocol = uint8(header.TCPProtocolNumber) + case *UDP: + fields.Protocol = uint8(header.UDPProtocolNumber) default: - // TODO(b/150301488): Support more protocols, like UDP. - return nil, fmt.Errorf("can't deduce the ip header's next protocol: %+v", n) + // TODO(b/150301488): Support more protocols as needed. + return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n) } } if l.SrcAddr != nil { @@ -275,9 +357,9 @@ func Address(v tcpip.Address) *tcpip.Address { return &v } -// ParseIPv4 parses the bytes assuming that they start with an ipv4 header and +// parseIPv4 parses the bytes assuming that they start with an ipv4 header and // continues parsing further encapsulations. -func ParseIPv4(b []byte) (Layers, error) { +func parseIPv4(b []byte) (Layer, layerParser) { h := header.IPv4(b) tos, _ := h.TOS() ipv4 := IPv4{ @@ -293,16 +375,17 @@ func ParseIPv4(b []byte) (Layers, error) { SrcAddr: Address(h.SourceAddress()), DstAddr: Address(h.DestinationAddress()), } - layers := Layers{&ipv4} - switch h.Protocol() { - case uint8(header.TCPProtocolNumber): - moreLayers, err := ParseTCP(b[ipv4.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + var nextParser layerParser + switch h.TransportProtocol() { + case header.TCPProtocolNumber: + nextParser = parseTCP + case header.UDPProtocolNumber: + nextParser = parseUDP + default: + // Assume that the rest is a payload. + nextParser = parsePayload } - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", h.Protocol()) + return &ipv4, nextParser } func (l *IPv4) match(other Layer) bool { @@ -316,7 +399,13 @@ func (l *IPv4) length() int { return int(*l.IHL) } -// TCP can construct and match the TCP excapulation. +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv4) merge(other Layer) error { + return mergeLayer(l, other) +} + +// TCP can construct and match a TCP encapsulation. type TCP struct { LayerBase SrcPort *uint16 @@ -330,6 +419,10 @@ type TCP struct { UrgentPointer *uint16 } +func (l *TCP) String() string { + return stringLayer(l) +} + func (l *TCP) toBytes() ([]byte, error) { b := make([]byte, header.TCPMinimumSize) h := header.TCP(b) @@ -347,12 +440,16 @@ func (l *TCP) toBytes() ([]byte, error) { } if l.DataOffset != nil { h.SetDataOffset(*l.DataOffset) + } else { + h.SetDataOffset(uint8(l.length())) } if l.Flags != nil { h.SetFlags(*l.Flags) } if l.WindowSize != nil { h.SetWindowSize(*l.WindowSize) + } else { + h.SetWindowSize(32768) } if l.UrgentPointer != nil { h.SetUrgentPoiner(*l.UrgentPointer) @@ -361,38 +458,52 @@ func (l *TCP) toBytes() ([]byte, error) { h.SetChecksum(*l.Checksum) return h, nil } - if err := setChecksum(&h, l); err != nil { + if err := setTCPChecksum(&h, l); err != nil { return nil, err } return h, nil } -// setChecksum calculates the checksum of the TCP header and sets it in h. -func setChecksum(h *header.TCP, tcp *TCP) error { - h.SetChecksum(0) - tcpLength := uint16(tcp.length()) - current := tcp.next() - for current != nil { - tcpLength += uint16(current.length()) - current = current.next() +// totalLength returns the length of the provided layer and all following +// layers. +func totalLength(l Layer) int { + var totalLength int + for ; l != nil; l = l.next() { + totalLength += l.length() } + return totalLength +} +// layerChecksum calculates the checksum of the Layer header, including the +// peusdeochecksum of the layer before it and all the bytes after it.. +func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) { + totalLength := uint16(totalLength(l)) var xsum uint16 - switch s := tcp.prev().(type) { + switch s := l.prev().(type) { case *IPv4: - xsum = header.PseudoHeaderChecksum(header.TCPProtocolNumber, *s.SrcAddr, *s.DstAddr, tcpLength) + xsum = header.PseudoHeaderChecksum(protoNumber, *s.SrcAddr, *s.DstAddr, totalLength) default: // TODO(b/150301488): Support more protocols, like IPv6. - return fmt.Errorf("can't get src and dst addr from previous layer") + return 0, fmt.Errorf("can't get src and dst addr from previous layer: %#v", s) } - current = tcp.next() - for current != nil { + var payloadBytes buffer.VectorisedView + for current := l.next(); current != nil; current = current.next() { payload, err := current.toBytes() if err != nil { - return fmt.Errorf("can't get bytes for next header: %s", payload) + return 0, fmt.Errorf("can't get bytes for next header: %s", payload) } - xsum = header.Checksum(payload, xsum) - current = current.next() + payloadBytes.AppendView(payload) + } + xsum = header.ChecksumVV(payloadBytes, xsum) + return xsum, nil +} + +// setTCPChecksum calculates the checksum of the TCP header and sets it in h. +func setTCPChecksum(h *header.TCP, tcp *TCP) error { + h.SetChecksum(0) + xsum, err := layerChecksum(tcp, header.TCPProtocolNumber) + if err != nil { + return err } h.SetChecksum(^h.CalculateChecksum(xsum)) return nil @@ -404,9 +515,9 @@ func Uint32(v uint32) *uint32 { return &v } -// ParseTCP parses the bytes assuming that they start with a tcp header and +// parseTCP parses the bytes assuming that they start with a tcp header and // continues parsing further encapsulations. -func ParseTCP(b []byte) (Layers, error) { +func parseTCP(b []byte) (Layer, layerParser) { h := header.TCP(b) tcp := TCP{ SrcPort: Uint16(h.SourcePort()), @@ -419,12 +530,7 @@ func ParseTCP(b []byte) (Layers, error) { Checksum: Uint16(h.Checksum()), UrgentPointer: Uint16(h.UrgentPointer()), } - layers := Layers{&tcp} - moreLayers, err := ParsePayload(b[tcp.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + return &tcp, parsePayload } func (l *TCP) match(other Layer) bool { @@ -440,8 +546,86 @@ func (l *TCP) length() int { // merge overrides the values in l with the values from other but only in fields // where the value is not nil. -func (l *TCP) merge(other TCP) error { - return mergo.Merge(l, other, mergo.WithOverride) +func (l *TCP) merge(other Layer) error { + return mergeLayer(l, other) +} + +// UDP can construct and match a UDP encapsulation. +type UDP struct { + LayerBase + SrcPort *uint16 + DstPort *uint16 + Length *uint16 + Checksum *uint16 +} + +func (l *UDP) String() string { + return stringLayer(l) +} + +func (l *UDP) toBytes() ([]byte, error) { + b := make([]byte, header.UDPMinimumSize) + h := header.UDP(b) + if l.SrcPort != nil { + h.SetSourcePort(*l.SrcPort) + } + if l.DstPort != nil { + h.SetDestinationPort(*l.DstPort) + } + if l.Length != nil { + h.SetLength(*l.Length) + } else { + h.SetLength(uint16(totalLength(l))) + } + if l.Checksum != nil { + h.SetChecksum(*l.Checksum) + return h, nil + } + if err := setUDPChecksum(&h, l); err != nil { + return nil, err + } + return h, nil +} + +// setUDPChecksum calculates the checksum of the UDP header and sets it in h. +func setUDPChecksum(h *header.UDP, udp *UDP) error { + h.SetChecksum(0) + xsum, err := layerChecksum(udp, header.UDPProtocolNumber) + if err != nil { + return err + } + h.SetChecksum(^h.CalculateChecksum(xsum)) + return nil +} + +// parseUDP parses the bytes assuming that they start with a udp header and +// returns the parsed layer and the next parser to use. +func parseUDP(b []byte) (Layer, layerParser) { + h := header.UDP(b) + udp := UDP{ + SrcPort: Uint16(h.SourcePort()), + DstPort: Uint16(h.DestinationPort()), + Length: Uint16(h.Length()), + Checksum: Uint16(h.Checksum()), + } + return &udp, parsePayload +} + +func (l *UDP) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *UDP) length() int { + if l.Length == nil { + return header.UDPMinimumSize + } + return int(*l.Length) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *UDP) merge(other Layer) error { + return mergeLayer(l, other) } // Payload has bytes beyond OSI layer 4. @@ -450,13 +634,17 @@ type Payload struct { Bytes []byte } -// ParsePayload parses the bytes assuming that they start with a payload and +func (l *Payload) String() string { + return stringLayer(l) +} + +// parsePayload parses the bytes assuming that they start with a payload and // continue to the end. There can be no further encapsulations. -func ParsePayload(b []byte) (Layers, error) { +func parsePayload(b []byte) (Layer, layerParser) { payload := Payload{ Bytes: b, } - return Layers{&payload}, nil + return &payload, nil } func (l *Payload) toBytes() ([]byte, error) { @@ -471,18 +659,33 @@ func (l *Payload) length() int { return len(l.Bytes) } +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *Payload) merge(other Layer) error { + return mergeLayer(l, other) +} + // Layers is an array of Layer and supports similar functions to Layer. type Layers []Layer -func (ls *Layers) toBytes() ([]byte, error) { +// linkLayers sets the linked-list ponters in ls. +func (ls *Layers) linkLayers() { for i, l := range *ls { if i > 0 { l.setPrev((*ls)[i-1]) + } else { + l.setPrev(nil) } if i+1 < len(*ls) { l.setNext((*ls)[i+1]) + } else { + l.setNext(nil) } } +} + +func (ls *Layers) toBytes() ([]byte, error) { + ls.linkLayers() outBytes := []byte{} for _, l := range *ls { layerBytes, err := l.toBytes() @@ -498,8 +701,8 @@ func (ls *Layers) match(other Layers) bool { if len(*ls) > len(other) { return false } - for i := 0; i < len(*ls); i++ { - if !equalLayer((*ls)[i], other[i]) { + for i, l := range *ls { + if !equalLayer(l, other[i]) { return false } } diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go new file mode 100644 index 000000000..b32efda93 --- /dev/null +++ b/test/packetimpact/testbench/layers_test.go @@ -0,0 +1,156 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +func TestLayerMatch(t *testing.T) { + var nilPayload *Payload + noPayload := &Payload{} + emptyPayload := &Payload{Bytes: []byte{}} + fullPayload := &Payload{Bytes: []byte{1, 2, 3}} + emptyTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: emptyPayload}} + fullTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: fullPayload}} + for _, tt := range []struct { + a, b Layer + want bool + }{ + {nilPayload, nilPayload, true}, + {nilPayload, noPayload, true}, + {nilPayload, emptyPayload, true}, + {nilPayload, fullPayload, true}, + {noPayload, noPayload, true}, + {noPayload, emptyPayload, true}, + {noPayload, fullPayload, true}, + {emptyPayload, emptyPayload, true}, + {emptyPayload, fullPayload, false}, + {fullPayload, fullPayload, true}, + {emptyTCP, fullTCP, true}, + } { + if got := tt.a.match(tt.b); got != tt.want { + t.Errorf("%s.match(%s) = %t, want %t", tt.a, tt.b, got, tt.want) + } + if got := tt.b.match(tt.a); got != tt.want { + t.Errorf("%s.match(%s) = %t, want %t", tt.b, tt.a, got, tt.want) + } + } +} + +func TestLayerStringFormat(t *testing.T) { + for _, tt := range []struct { + name string + l Layer + want string + }{ + { + name: "TCP", + l: &TCP{ + SrcPort: Uint16(34785), + DstPort: Uint16(47767), + SeqNum: Uint32(3452155723), + AckNum: Uint32(2596996163), + DataOffset: Uint8(5), + Flags: Uint8(20), + WindowSize: Uint16(64240), + Checksum: Uint16(0x2e2b), + }, + want: "&testbench.TCP{" + + "SrcPort:34785 " + + "DstPort:47767 " + + "SeqNum:3452155723 " + + "AckNum:2596996163 " + + "DataOffset:5 " + + "Flags:20 " + + "WindowSize:64240 " + + "Checksum:11819" + + "}", + }, + { + name: "UDP", + l: &UDP{ + SrcPort: Uint16(34785), + DstPort: Uint16(47767), + Length: Uint16(12), + }, + want: "&testbench.UDP{" + + "SrcPort:34785 " + + "DstPort:47767 " + + "Length:12" + + "}", + }, + { + name: "IPv4", + l: &IPv4{ + IHL: Uint8(5), + TOS: Uint8(0), + TotalLength: Uint16(44), + ID: Uint16(0), + Flags: Uint8(2), + FragmentOffset: Uint16(0), + TTL: Uint8(64), + Protocol: Uint8(6), + Checksum: Uint16(0x2e2b), + SrcAddr: Address(tcpip.Address([]byte{197, 34, 63, 10})), + DstAddr: Address(tcpip.Address([]byte{197, 34, 63, 20})), + }, + want: "&testbench.IPv4{" + + "IHL:5 " + + "TOS:0 " + + "TotalLength:44 " + + "ID:0 " + + "Flags:2 " + + "FragmentOffset:0 " + + "TTL:64 " + + "Protocol:6 " + + "Checksum:11819 " + + "SrcAddr:197.34.63.10 " + + "DstAddr:197.34.63.20" + + "}", + }, + { + name: "Ether", + l: &Ether{ + SrcAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x0a})), + DstAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x14})), + Type: NetworkProtocolNumber(4), + }, + want: "&testbench.Ether{" + + "SrcAddr:02:42:c5:22:3f:0a " + + "DstAddr:02:42:c5:22:3f:14 " + + "Type:4" + + "}", + }, + { + name: "Payload", + l: &Payload{ + Bytes: []byte("Hooray for packetimpact."), + }, + want: "&testbench.Payload{Bytes:\n" + + "00000000 48 6f 6f 72 61 79 20 66 6f 72 20 70 61 63 6b 65 |Hooray for packe|\n" + + "00000010 74 69 6d 70 61 63 74 2e |timpact.|\n" + + "}", + }, + } { + t.Run(tt.name, func(t *testing.T) { + if got := tt.l.String(); got != tt.want { + t.Errorf("%s.String() = %s, want: %s", tt.name, got, tt.want) + } + }) + } +} diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go index 0c7d0f979..ff722d4a6 100644 --- a/test/packetimpact/testbench/rawsockets.go +++ b/test/packetimpact/testbench/rawsockets.go @@ -17,6 +17,7 @@ package testbench import ( "encoding/binary" "flag" + "fmt" "math" "net" "testing" @@ -47,6 +48,12 @@ func NewSniffer(t *testing.T) (Sniffer, error) { if err != nil { return Sniffer{}, err } + if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, 1); err != nil { + t.Fatalf("can't set sockopt SO_RCVBUFFORCE to 1: %s", err) + } + if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1e7); err != nil { + t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err) + } return Sniffer{ t: t, fd: snifferFd, @@ -91,12 +98,36 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte { } } -// Close the socket that Sniffer is using. -func (s *Sniffer) Close() { +// Drain drains the Sniffer's socket receive buffer by receiving until there's +// nothing else to receive. +func (s *Sniffer) Drain() { + s.t.Helper() + flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0) + if err != nil { + s.t.Fatalf("failed to get sniffer socket fd flags: %s", err) + } + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags|unix.O_NONBLOCK); err != nil { + s.t.Fatalf("failed to make sniffer socket non-blocking: %s", err) + } + for { + buf := make([]byte, maxReadSize) + _, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC) + if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK { + break + } + } + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil { + s.t.Fatalf("failed to restore sniffer socket fd flags: %s", err) + } +} + +// close the socket that Sniffer is using. +func (s *Sniffer) close() error { if err := unix.Close(s.fd); err != nil { - s.t.Fatalf("can't close sniffer socket: %s", err) + return fmt.Errorf("can't close sniffer socket: %w", err) } s.fd = -1 + return nil } // Injector can inject raw frames. @@ -142,10 +173,11 @@ func (i *Injector) Send(b []byte) { } } -// Close the underlying socket. -func (i *Injector) Close() { +// close the underlying socket. +func (i *Injector) close() error { if err := unix.Close(i.fd); err != nil { - i.t.Fatalf("can't close sniffer socket: %s", err) + return fmt.Errorf("can't close sniffer socket: %w", err) } i.fd = -1 + return nil } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 1dff2a4d5..47c722ccd 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -15,6 +15,87 @@ packetimpact_go_test( ], ) +packetimpact_go_test( + name = "udp_recv_multicast", + srcs = ["udp_recv_multicast_test.go"], + # TODO(b/152813495): Fix netstack then remove the line below. + netstack = False, + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_window_shrink", + srcs = ["tcp_window_shrink_test.go"], + # TODO(b/153202472): Fix netstack then remove the line below. + netstack = False, + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_outside_the_window", + srcs = ["tcp_outside_the_window_test.go"], + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_noaccept_close_rst", + srcs = ["tcp_noaccept_close_rst_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_should_piggyback", + srcs = ["tcp_should_piggyback_test.go"], + # TODO(b/153680566): Fix netstack then remove the line below. + netstack = False, + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_close_wait_ack", + srcs = ["tcp_close_wait_ack_test.go"], + # TODO(b/153574037): Fix netstack then remove the line below. + netstack = False, + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_user_timeout", + srcs = ["tcp_user_timeout_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + sh_binary( name = "test_runner", srcs = ["test_runner.sh"], diff --git a/test/packetimpact/tests/Dockerfile b/test/packetimpact/tests/Dockerfile index 507030cc7..9075bc555 100644 --- a/test/packetimpact/tests/Dockerfile +++ b/test/packetimpact/tests/Dockerfile @@ -1,5 +1,17 @@ FROM ubuntu:bionic -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y iptables netcat tcpdump iproute2 tshark +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + # iptables to disable OS native packet processing. + iptables \ + # nc to check that the posix_server is running. + netcat \ + # tcpdump to log brief packet sniffing. + tcpdump \ + # ip link show to display MAC addresses. + iproute2 \ + # tshark to log verbose packet sniffing. + tshark \ + # killall for cleanup. + psmisc RUN hash -r CMD /bin/bash diff --git a/test/packetimpact/tests/defs.bzl b/test/packetimpact/tests/defs.bzl index 1b4213d9b..8c0d058b2 100644 --- a/test/packetimpact/tests/defs.bzl +++ b/test/packetimpact/tests/defs.bzl @@ -93,7 +93,17 @@ def packetimpact_netstack_test(name, testbench_binary, **kwargs): **kwargs ) -def packetimpact_go_test(name, size = "small", pure = True, **kwargs): +def packetimpact_go_test(name, size = "small", pure = True, linux = True, netstack = True, **kwargs): + """Add packetimpact tests written in go. + + Args: + name: name of the test + size: size of the test + pure: make a static go binary + linux: generate a linux test + netstack: generate a netstack test + **kwargs: all the other args, forwarded to go_test + """ testbench_binary = name + "_test" go_test( name = testbench_binary, @@ -102,5 +112,7 @@ def packetimpact_go_test(name, size = "small", pure = True, **kwargs): tags = PACKETIMPACT_TAGS, **kwargs ) - packetimpact_linux_test(name = name, testbench_binary = testbench_binary) - packetimpact_netstack_test(name = name, testbench_binary = testbench_binary) + if linux: + packetimpact_linux_test(name = name, testbench_binary = testbench_binary) + if netstack: + packetimpact_netstack_test(name = name, testbench_binary = testbench_binary) diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go index 5f54e67ed..b98594f94 100644 --- a/test/packetimpact/tests/fin_wait2_timeout_test.go +++ b/test/packetimpact/tests/fin_wait2_timeout_test.go @@ -36,7 +36,7 @@ func TestFinWait2Timeout(t *testing.T) { defer dut.TearDown() listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) defer dut.Close(listenFd) - conn := tb.NewTCPIPv4(t, dut, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) defer conn.Close() conn.Handshake() @@ -47,20 +47,22 @@ func TestFinWait2Timeout(t *testing.T) { } dut.Close(acceptFd) - if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); gotOne == nil { - t.Fatal("expected a FIN-ACK within 1 second but got none") + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err) } conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) time.Sleep(5 * time.Second) + conn.Drain() + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) if tt.linger2 { - if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); gotOne == nil { - t.Fatal("expected a RST packet within a second but got none") + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected a RST packet within a second but got none: %s", err) } } else { - if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, 10*time.Second); gotOne != nil { - t.Fatal("expected no RST packets within ten seconds but got one") + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, 10*time.Second); err == nil { + t.Fatalf("expected no RST packets within ten seconds but got one: %s", err) } } }) diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go new file mode 100644 index 000000000..eb4cc7a65 --- /dev/null +++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go @@ -0,0 +1,102 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_close_wait_ack_test + +import ( + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestCloseWaitAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP + seqNumOffset seqnum.Size + expectAck bool + }{ + {"OTW", GenerateOTWSeqSegment, 0, false}, + {"OTW", GenerateOTWSeqSegment, 1, true}, + {"OTW", GenerateOTWSeqSegment, 2, true}, + {"ACK", GenerateUnaccACKSegment, 0, false}, + {"ACK", GenerateUnaccACKSegment, 1, true}, + {"ACK", GenerateUnaccACKSegment, 2, true}, + } { + t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Handshake() + acceptFd, _ := dut.Accept(listenFd) + + // Send a FIN to DUT to intiate the active close + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err) + } + + // Send a segment with OTW Seq / unacc ACK and expect an ACK back + conn.Send(tt.makeTestingTCP(&conn, tt.seqNumOffset), &tb.Payload{Bytes: []byte("Sample Data")}) + gotAck, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second) + if tt.expectAck && err != nil { + t.Fatalf("expected an ack but got none: %s", err) + } + if !tt.expectAck && gotAck != nil { + t.Fatalf("expected no ack but got one: %s", gotAck) + } + + // Now let's verify DUT is indeed in CLOSE_WAIT + dut.Close(acceptFd) + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { + t.Fatalf("expected DUT to send a FIN: %s", err) + } + // Ack the FIN from DUT + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + // Send some extra data to DUT + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, &tb.Payload{Bytes: []byte("Sample Data")}) + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected DUT to send an RST: %s", err) + } + }) + } +} + +// This generates an segment with seqnum = RCV.NXT + RCV.WND + seqNumOffset, the +// generated segment is only acceptable when seqNumOffset is 0, otherwise an ACK +// is expected from the receiver. +func GenerateOTWSeqSegment(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP { + windowSize := seqnum.Size(*conn.SynAck().WindowSize) + lastAcceptable := conn.LocalSeqNum().Add(windowSize - 1) + otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) + return tb.TCP{SeqNum: tb.Uint32(otwSeq), Flags: tb.Uint8(header.TCPFlagAck)} +} + +// This generates an segment with acknum = SND.NXT + seqNumOffset, the generated +// segment is only acceptable when seqNumOffset is 0, otherwise an ACK is +// expected from the receiver. +func GenerateUnaccACKSegment(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP { + lastAcceptable := conn.RemoteSeqNum() + unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) + return tb.TCP{AckNum: tb.Uint32(unaccAck), Flags: tb.Uint8(header.TCPFlagAck)} +} diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go new file mode 100644 index 000000000..7ebdd1950 --- /dev/null +++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_noaccept_close_rst_test + +import ( + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestTcpNoAcceptCloseReset(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + conn.Handshake() + defer conn.Close() + dut.Close(listenFd) + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { + t.Fatalf("expected a RST-ACK packet but got none: %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go new file mode 100644 index 000000000..db3d3273b --- /dev/null +++ b/test/packetimpact/tests/tcp_outside_the_window_test.go @@ -0,0 +1,88 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_outside_the_window_test + +import ( + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +// TestTCPOutsideTheWindows tests the behavior of the DUT when packets arrive +// that are inside or outside the TCP window. Packets that are outside the +// window should force an extra ACK, as described in RFC793 page 69: +// https://tools.ietf.org/html/rfc793#page-69 +func TestTCPOutsideTheWindow(t *testing.T) { + for _, tt := range []struct { + description string + tcpFlags uint8 + payload []tb.Layer + seqNumOffset seqnum.Size + expectACK bool + }{ + {"SYN", header.TCPFlagSyn, nil, 0, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 0, true}, + {"ACK", header.TCPFlagAck, nil, 0, false}, + {"FIN", header.TCPFlagFin, nil, 0, false}, + {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 0, true}, + + {"SYN", header.TCPFlagSyn, nil, 1, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 1, true}, + {"ACK", header.TCPFlagAck, nil, 1, true}, + {"FIN", header.TCPFlagFin, nil, 1, false}, + {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 1, true}, + + {"SYN", header.TCPFlagSyn, nil, 2, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 2, true}, + {"ACK", header.TCPFlagAck, nil, 2, true}, + {"FIN", header.TCPFlagFin, nil, 2, false}, + {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 2, true}, + } { + t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + conn.Handshake() + acceptFD, _ := dut.Accept(listenFD) + defer dut.Close(acceptFD) + + windowSize := seqnum.Size(*conn.SynAck().WindowSize) + tt.seqNumOffset + conn.Drain() + // Ignore whatever incrementing that this out-of-order packet might cause + // to the AckNum. + localSeqNum := tb.Uint32(uint32(*conn.LocalSeqNum())) + conn.Send(tb.TCP{ + Flags: tb.Uint8(tt.tcpFlags), + SeqNum: tb.Uint32(uint32(conn.LocalSeqNum().Add(windowSize))), + }, tt.payload...) + timeout := 3 * time.Second + gotACK, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout) + if tt.expectACK && err != nil { + t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err) + } + if !tt.expectACK && gotACK != nil { + t.Fatalf("expected no ACK packet within %s but got one: %s", timeout, gotACK) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_should_piggyback_test.go b/test/packetimpact/tests/tcp_should_piggyback_test.go new file mode 100644 index 000000000..b0be6ba23 --- /dev/null +++ b/test/packetimpact/tests/tcp_should_piggyback_test.go @@ -0,0 +1,59 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_should_piggyback_test + +import ( + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestPiggyback(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort, WindowSize: tb.Uint16(12)}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Handshake() + acceptFd, _ := dut.Accept(listenFd) + defer dut.Close(acceptFd) + + dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + + dut.Send(acceptFd, sampleData, 0) + expectedTCP := tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagPsh)} + expectedPayload := tb.Payload{Bytes: sampleData} + if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + t.Fatalf("Expected %v but didn't get one: %s", tb.Layers{&expectedTCP, &expectedPayload}, err) + } + + // Cause DUT to send us more data as soon as we ACK their first data segment because we have + // a small window. + dut.Send(acceptFd, sampleData, 0) + + // DUT should ACK our segment by piggybacking ACK to their outstanding data segment instead of + // sending a separate ACK packet. + conn.Send(expectedTCP, &expectedPayload) + if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + t.Fatalf("Expected %v but didn't get one: %s", tb.Layers{&expectedTCP, &expectedPayload}, err) + } +} diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go new file mode 100644 index 000000000..3cf82badb --- /dev/null +++ b/test/packetimpact/tests/tcp_user_timeout_test.go @@ -0,0 +1,100 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_user_timeout_test + +import ( + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func sendPayload(conn *tb.TCPIPv4, dut *tb.DUT, fd int32) error { + sampleData := make([]byte, 100) + for i := range sampleData { + sampleData[i] = uint8(i) + } + conn.Drain() + dut.Send(fd, sampleData, 0) + if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &tb.Payload{Bytes: sampleData}, time.Second); err != nil { + return fmt.Errorf("expected data but got none: %w", err) + } + return nil +} + +func sendFIN(conn *tb.TCPIPv4, dut *tb.DUT, fd int32) error { + dut.Close(fd) + return nil +} + +func TestTCPUserTimeout(t *testing.T) { + for _, tt := range []struct { + description string + userTimeout time.Duration + sendDelay time.Duration + }{ + {"NoUserTimeout", 0, 3 * time.Second}, + {"ACKBeforeUserTimeout", 5 * time.Second, 4 * time.Second}, + {"ACKAfterUserTimeout", 5 * time.Second, 7 * time.Second}, + } { + for _, ttf := range []struct { + description string + f func(conn *tb.TCPIPv4, dut *tb.DUT, fd int32) error + }{ + {"AfterPayload", sendPayload}, + {"AfterFIN", sendFIN}, + } { + t.Run(tt.description+ttf.description, func(t *testing.T) { + // Create a socket, listen, TCP handshake, and accept. + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + conn.Handshake() + acceptFD, _ := dut.Accept(listenFD) + + if tt.userTimeout != 0 { + dut.SetSockOptInt(acceptFD, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int32(tt.userTimeout.Milliseconds())) + } + + if err := ttf.f(&conn, &dut, acceptFD); err != nil { + t.Fatal(err) + } + + time.Sleep(tt.sendDelay) + conn.Drain() + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + + // If TCP_USER_TIMEOUT was set and the above delay was longer than the + // TCP_USER_TIMEOUT then the DUT should send a RST in response to the + // testbench's packet. + expectRST := tt.userTimeout != 0 && tt.sendDelay > tt.userTimeout + expectTimeout := 5 * time.Second + got, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, expectTimeout) + if expectRST && err != nil { + t.Errorf("expected RST packet within %s but got none: %s", expectTimeout, err) + } + if !expectRST && got != nil { + t.Errorf("expected no RST packet within %s but got one: %s", expectTimeout, got) + } + }) + } + } +} diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go new file mode 100644 index 000000000..c9354074e --- /dev/null +++ b/test/packetimpact/tests/tcp_window_shrink_test.go @@ -0,0 +1,68 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_window_shrink_test + +import ( + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestWindowShrink(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Handshake() + acceptFd, _ := dut.Accept(listenFd) + defer dut.Close(acceptFd) + + dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + samplePayload := &tb.Payload{Bytes: sampleData} + + dut.Send(acceptFd, sampleData, 0) + if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + + dut.Send(acceptFd, sampleData, 0) + dut.Send(acceptFd, sampleData, 0) + if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + // We close our receiving window here + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), WindowSize: tb.Uint16(0)}) + + dut.Send(acceptFd, []byte("Sample Data"), 0) + // Note: There is another kind of zero-window probing which Windows uses (by sending one + // new byte at `RemoteSeqNum`), if netstack wants to go that way, we may want to change + // the following lines. + expectedRemoteSeqNum := *conn.RemoteSeqNum() - 1 + if _, err := conn.ExpectData(&tb.TCP{SeqNum: tb.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil { + t.Fatalf("expected a packet with sequence number %v: %s", expectedRemoteSeqNum, err) + } +} diff --git a/test/packetimpact/tests/test_runner.sh b/test/packetimpact/tests/test_runner.sh index 5281cb53d..e99fc7d09 100755 --- a/test/packetimpact/tests/test_runner.sh +++ b/test/packetimpact/tests/test_runner.sh @@ -29,13 +29,15 @@ function failure() { } trap 'failure ${LINENO} "$BASH_COMMAND"' ERR -declare -r LONGOPTS="dut_platform:,posix_server_binary:,testbench_binary:,runtime:,tshark" +declare -r LONGOPTS="dut_platform:,posix_server_binary:,testbench_binary:,runtime:,tshark,extra_test_arg:" # Don't use declare below so that the error from getopt will end the script. PARSED=$(getopt --options "" --longoptions=$LONGOPTS --name "$0" -- "$@") eval set -- "$PARSED" +declare -a EXTRA_TEST_ARGS + while true; do case "$1" in --dut_platform) @@ -62,6 +64,10 @@ while true; do declare -r TSHARK="1" shift 1 ;; + --extra_test_arg) + EXTRA_TEST_ARGS+="$2" + shift 2 + ;; --) shift break @@ -125,6 +131,19 @@ docker --version function finish { local cleanup_success=1 + + if [[ -z "${TSHARK-}" ]]; then + # Kill tcpdump so that it will flush output. + docker exec -t "${TESTBENCH}" \ + killall tcpdump || \ + cleanup_success=0 + else + # Kill tshark so that it will flush output. + docker exec -t "${TESTBENCH}" \ + killall tshark || \ + cleanup_success=0 + fi + for net in "${CTRL_NET}" "${TEST_NET}"; do # Kill all processes attached to ${net}. for docker_command in "kill" "rm"; do @@ -224,6 +243,8 @@ else # interface with the test packets. docker exec -t "${TESTBENCH}" \ tshark -V -l -n -i "${TEST_DEVICE}" \ + -o tcp.check_checksum:TRUE \ + -o udp.check_checksum:TRUE \ host "${TEST_NET_PREFIX}${TESTBENCH_NET_SUFFIX}" & fi @@ -235,6 +256,7 @@ sleep 3 # be executed on the DUT. docker exec -t "${TESTBENCH}" \ /bin/bash -c "${DOCKER_TESTBENCH_BINARY} \ + ${EXTRA_TEST_ARGS[@]-} \ --posix_server_ip=${CTRL_NET_PREFIX}${DUT_NET_SUFFIX} \ --posix_server_port=${CTRL_PORT} \ --remote_ipv4=${TEST_NET_PREFIX}${DUT_NET_SUFFIX} \ diff --git a/test/packetimpact/tests/udp_recv_multicast_test.go b/test/packetimpact/tests/udp_recv_multicast_test.go new file mode 100644 index 000000000..61fd17050 --- /dev/null +++ b/test/packetimpact/tests/udp_recv_multicast_test.go @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_recv_multicast_test + +import ( + "net" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestUDPRecvMulticast(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(boundFD) + conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort}) + defer conn.Close() + frame := conn.CreateFrame(&tb.UDP{}, &tb.Payload{Bytes: []byte("hello world")}) + frame[1].(*tb.IPv4).DstAddr = tb.Address(tcpip.Address(net.ParseIP("224.0.0.1").To4())) + conn.SendFrame(frame) + dut.Recv(boundFD, 100, 0) +} diff --git a/test/perf/BUILD b/test/perf/BUILD index 0a0def6a3..471d8c2ab 100644 --- a/test/perf/BUILD +++ b/test/perf/BUILD @@ -30,6 +30,7 @@ syscall_test( syscall_test( size = "enormous", + shard_count = 10, tags = ["nogotsan"], test = "//test/perf/linux:getdents_benchmark", ) diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc index afc599ad2..d8e81fa8c 100644 --- a/test/perf/linux/getdents_benchmark.cc +++ b/test/perf/linux/getdents_benchmark.cc @@ -38,7 +38,7 @@ namespace testing { namespace { -constexpr int kBufferSize = 16384; +constexpr int kBufferSize = 65536; PosixErrorOr<TempPath> CreateDirectory(int count, std::vector<std::string>* files) { diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go index 4038661cb..679342def 100644 --- a/test/root/cgroup_test.go +++ b/test/root/cgroup_test.go @@ -53,7 +53,7 @@ func verifyPid(pid int, path string) error { if scanner.Err() != nil { return scanner.Err() } - return fmt.Errorf("got: %s, want: %d", gots, pid) + return fmt.Errorf("got: %v, want: %d", gots, pid) } // TestCgroup sets cgroup options and checks that cgroup was properly configured. @@ -106,7 +106,7 @@ func TestMemCGroup(t *testing.T) { time.Sleep(100 * time.Millisecond) } - t.Fatalf("%vMB is less than %vMB: %v", memUsage>>20, allocMemSize>>20) + t.Fatalf("%vMB is less than %vMB", memUsage>>20, allocMemSize>>20) } // TestCgroup sets cgroup options and checks that cgroup was properly configured. diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go index 126f0975a..22488b05d 100644 --- a/test/root/oom_score_adj_test.go +++ b/test/root/oom_score_adj_test.go @@ -46,7 +46,7 @@ func TestOOMScoreAdjSingle(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir ppid, err := specutils.GetParentPid(os.Getpid()) @@ -137,7 +137,7 @@ func TestOOMScoreAdjMulti(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir ppid, err := specutils.GetParentPid(os.Getpid()) diff --git a/test/runtimes/blacklist_test.go b/test/runtimes/blacklist_test.go index 52f49b984..0ff69ab18 100644 --- a/test/runtimes/blacklist_test.go +++ b/test/runtimes/blacklist_test.go @@ -32,6 +32,6 @@ func TestBlacklists(t *testing.T) { t.Fatalf("error parsing blacklist: %v", err) } if *blacklistFile != "" && len(bl) == 0 { - t.Errorf("got empty blacklist for file %q", blacklistFile) + t.Errorf("got empty blacklist for file %q", *blacklistFile) } } diff --git a/test/runtimes/runner.go b/test/runtimes/runner.go index ddb890dbc..3c98f4570 100644 --- a/test/runtimes/runner.go +++ b/test/runtimes/runner.go @@ -114,7 +114,7 @@ func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.Int F: func(t *testing.T) { // Is the test blacklisted? if _, ok := blacklist[tc]; ok { - t.Skip("SKIP: blacklisted test %q", tc) + t.Skipf("SKIP: blacklisted test %q", tc) } var ( diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index d0c431234..d9095c95f 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -138,7 +138,6 @@ cc_library( hdrs = ["socket_netlink_route_util.h"], deps = [ ":socket_netlink_util", - "@com_google_absl//absl/types:optional", ], ) @@ -663,10 +662,7 @@ cc_binary( cc_binary( name = "exec_binary_test", testonly = 1, - srcs = select_arch( - amd64 = ["exec_binary.cc"], - arm64 = [], - ), + srcs = ["exec_binary.cc"], linkstatic = 1, deps = [ "//test/util:cleanup", @@ -2026,6 +2022,8 @@ cc_binary( "//test/util:file_descriptor", "@com_google_absl//absl/strings", gtest, + ":ip_socket_test_util", + ":unix_domain_socket_test_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", @@ -2802,13 +2800,13 @@ cc_binary( srcs = ["socket_netlink_route.cc"], linkstatic = 1, deps = [ + ":socket_netlink_route_util", ":socket_netlink_util", ":socket_test_util", "//test/util:capability_util", "//test/util:cleanup", "//test/util:file_descriptor", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:optional", gtest, "//test/util:test_main", "//test/util:test_util", diff --git a/test/syscalls/linux/aio.cc b/test/syscalls/linux/aio.cc index a33daff17..806d5729e 100644 --- a/test/syscalls/linux/aio.cc +++ b/test/syscalls/linux/aio.cc @@ -89,6 +89,7 @@ class AIOTest : public FileTest { FileTest::TearDown(); if (ctx_ != 0) { ASSERT_THAT(DestroyContext(), SyscallSucceeds()); + ctx_ = 0; } } @@ -188,14 +189,19 @@ TEST_F(AIOTest, BadWrite) { } TEST_F(AIOTest, ExitWithPendingIo) { - // Setup a context that is 5 entries deep. - ASSERT_THAT(SetupContext(5), SyscallSucceeds()); + // Setup a context that is 100 entries deep. + ASSERT_THAT(SetupContext(100), SyscallSucceeds()); struct iocb cb = CreateCallback(); struct iocb* cbs[] = {&cb}; // Submit a request but don't complete it to make it pending. - EXPECT_THAT(Submit(1, cbs), SyscallSucceeds()); + for (int i = 0; i < 100; ++i) { + EXPECT_THAT(Submit(1, cbs), SyscallSucceeds()); + } + + ASSERT_THAT(DestroyContext(), SyscallSucceeds()); + ctx_ = 0; } int Submitter(void* arg) { diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc index a4f8f3cec..f57d38dc7 100644 --- a/test/syscalls/linux/epoll.cc +++ b/test/syscalls/linux/epoll.cc @@ -56,10 +56,6 @@ TEST(EpollTest, AllWritable) { struct epoll_event result[kFDsPerEpoll]; ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1), SyscallSucceedsWithValue(kFDsPerEpoll)); - // TODO(edahlgren): Why do some tests check epoll_event::data, and others - // don't? Does Linux actually guarantee that, in any of these test cases, - // epoll_wait will necessarily write out the epoll_events in the order that - // they were registered? for (int i = 0; i < kFDsPerEpoll; i++) { ASSERT_EQ(result[i].events, EPOLLOUT); } diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc index 07bd527e6..12c9b05ca 100644 --- a/test/syscalls/linux/exec.cc +++ b/test/syscalls/linux/exec.cc @@ -812,26 +812,28 @@ void ExecFromThread() { bool ValidateProcCmdlineVsArgv(const int argc, const char* const* argv) { auto contents_or = GetContents("/proc/self/cmdline"); if (!contents_or.ok()) { - std::cerr << "Unable to get /proc/self/cmdline: " << contents_or.error(); + std::cerr << "Unable to get /proc/self/cmdline: " << contents_or.error() + << std::endl; return false; } auto contents = contents_or.ValueOrDie(); if (contents.back() != '\0') { - std::cerr << "Non-null terminated /proc/self/cmdline!"; + std::cerr << "Non-null terminated /proc/self/cmdline!" << std::endl; return false; } contents.pop_back(); std::vector<std::string> procfs_cmdline = absl::StrSplit(contents, '\0'); if (static_cast<int>(procfs_cmdline.size()) != argc) { - std::cerr << "argc = " << argc << " != " << procfs_cmdline.size(); + std::cerr << "argc = " << argc << " != " << procfs_cmdline.size() + << std::endl; return false; } for (int i = 0; i < argc; ++i) { if (procfs_cmdline[i] != argv[i]) { std::cerr << "Procfs command line argument " << i << " mismatch " - << procfs_cmdline[i] << " != " << argv[i]; + << procfs_cmdline[i] << " != " << argv[i] << std::endl; return false; } } diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc index 736452b0c..1a9f203b9 100644 --- a/test/syscalls/linux/exec_binary.cc +++ b/test/syscalls/linux/exec_binary.cc @@ -48,10 +48,17 @@ namespace { using ::testing::AnyOf; using ::testing::Eq; -#ifndef __x86_64__ +#if !defined(__x86_64__) && !defined(__aarch64__) // The assembly stub and ELF internal details must be ported to other arches. -#error "Test only supported on x86-64" -#endif // __x86_64__ +#error "Test only supported on x86-64/arm64" +#endif // __x86_64__ || __aarch64__ + +#if defined(__x86_64__) +#define EM_TYPE EM_X86_64 +#define IP_REG(p) ((p).rip) +#define RAX_REG(p) ((p).rax) +#define RDI_REG(p) ((p).rdi) +#define RETURN_REG(p) ((p).rax) // amd64 stub that calls PTRACE_TRACEME and sends itself SIGSTOP. const char kPtraceCode[] = { @@ -139,6 +146,76 @@ const char kPtraceCode[] = { // Size of a syscall instruction. constexpr int kSyscallSize = 2; +#elif defined(__aarch64__) +#define EM_TYPE EM_AARCH64 +#define IP_REG(p) ((p).pc) +#define RAX_REG(p) ((p).regs[8]) +#define RDI_REG(p) ((p).regs[0]) +#define RETURN_REG(p) ((p).regs[0]) + +const char kPtraceCode[] = { + // MOVD $117, R8 /* ptrace */ + '\xa8', + '\x0e', + '\x80', + '\xd2', + // MOVD $0, R0 /* PTRACE_TRACEME */ + '\x00', + '\x00', + '\x80', + '\xd2', + // MOVD $0, R1 /* pid */ + '\x01', + '\x00', + '\x80', + '\xd2', + // MOVD $0, R2 /* addr */ + '\x02', + '\x00', + '\x80', + '\xd2', + // MOVD $0, R3 /* data */ + '\x03', + '\x00', + '\x80', + '\xd2', + // SVC + '\x01', + '\x00', + '\x00', + '\xd4', + // MOVD $172, R8 /* getpid */ + '\x88', + '\x15', + '\x80', + '\xd2', + // SVC + '\x01', + '\x00', + '\x00', + '\xd4', + // MOVD $129, R8 /* kill, R0=pid */ + '\x28', + '\x10', + '\x80', + '\xd2', + // MOVD $19, R1 /* SIGSTOP */ + '\x61', + '\x02', + '\x80', + '\xd2', + // SVC + '\x01', + '\x00', + '\x00', + '\xd4', +}; +// Size of a syscall instruction. +constexpr int kSyscallSize = 4; +#else +#error "Unknown architecture" +#endif + // This test suite tests executable loading in the kernel (ELF and interpreter // scripts). @@ -281,7 +358,7 @@ ElfBinary<64> StandardElf() { elf.header.e_ident[EI_DATA] = ELFDATA2LSB; elf.header.e_ident[EI_VERSION] = EV_CURRENT; elf.header.e_type = ET_EXEC; - elf.header.e_machine = EM_X86_64; + elf.header.e_machine = EM_TYPE; elf.header.e_version = EV_CURRENT; elf.header.e_phoff = sizeof(elf.header); elf.header.e_phentsize = sizeof(decltype(elf)::ElfPhdr); @@ -327,9 +404,15 @@ TEST(ElfTest, Execute) { ASSERT_NO_ERRNO(WaitStopped(child)); struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); - // RIP is just beyond the final syscall instruction. - EXPECT_EQ(regs.rip, elf.header.e_entry + sizeof(kPtraceCode)); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); + // RIP/PC is just beyond the final syscall instruction. + EXPECT_EQ(IP_REG(regs), elf.header.e_entry + sizeof(kPtraceCode)); EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0, @@ -718,9 +801,16 @@ TEST(ElfTest, PIE) { // RIP tells us which page the first segment was loaded into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1); EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({ // text page. @@ -787,9 +877,15 @@ TEST(ElfTest, PIENonZeroStart) { // RIP tells us which page the first segment was loaded into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1); // The ELF is loaded at an arbitrary address, not the first PT_LOAD vaddr. // @@ -910,9 +1006,15 @@ TEST(ElfTest, ELFInterpreter) { // RIP tells us which page the first segment of the interpreter was loaded // into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1); EXPECT_THAT( child, ContainsMappings(std::vector<ProcMapsEntry>({ @@ -1084,9 +1186,15 @@ TEST(ElfTest, ELFInterpreterRelative) { // RIP tells us which page the first segment of the interpreter was loaded // into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1); EXPECT_THAT( child, ContainsMappings(std::vector<ProcMapsEntry>({ @@ -1480,14 +1588,21 @@ TEST(ExecveTest, BrkAfterBinary) { ASSERT_NO_ERRNO(WaitStopped(child)); struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); // RIP is just beyond the final syscall instruction. Rewind to execute a brk // syscall. - regs.rip -= kSyscallSize; - regs.rax = __NR_brk; - regs.rdi = 0; - ASSERT_THAT(ptrace(PTRACE_SETREGS, child, 0, ®s), SyscallSucceeds()); + IP_REG(regs) -= kSyscallSize; + RAX_REG(regs) = __NR_brk; + RDI_REG(regs) = 0; + ASSERT_THAT(ptrace(PTRACE_SETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); // Resume the child, waiting for syscall entry. ASSERT_THAT(ptrace(PTRACE_SYSCALL, child, 0, 0), SyscallSucceeds()); @@ -1504,7 +1619,12 @@ TEST(ExecveTest, BrkAfterBinary) { ASSERT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP) << "status = " << status; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); // brk is after the text page. // @@ -1512,7 +1632,7 @@ TEST(ExecveTest, BrkAfterBinary) { // address will be, but it is always beyond the final page in the binary. // i.e., it does not start immediately after memsz in the middle of a page. // Userspace may expect to use that space. - EXPECT_GE(regs.rax, 0x41000); + EXPECT_GE(RETURN_REG(regs), 0x41000); } } // namespace diff --git a/test/syscalls/linux/file_base.h b/test/syscalls/linux/file_base.h index 6f80bc97c..fb418e052 100644 --- a/test/syscalls/linux/file_base.h +++ b/test/syscalls/linux/file_base.h @@ -52,17 +52,6 @@ class FileTest : public ::testing::Test { test_file_fd_ = ASSERT_NO_ERRNO_AND_VALUE( Open(test_file_name_, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR)); - // FIXME(edahlgren): enable when mknod syscall is supported. - // test_fifo_name_ = NewTempAbsPath(); - // ASSERT_THAT(mknod(test_fifo_name_.c_str()), S_IFIFO|0644, 0, - // SyscallSucceeds()); - // ASSERT_THAT(test_fifo_[1] = open(test_fifo_name_.c_str(), - // O_WRONLY), - // SyscallSucceeds()); - // ASSERT_THAT(test_fifo_[0] = open(test_fifo_name_.c_str(), - // O_RDONLY), - // SyscallSucceeds()); - ASSERT_THAT(pipe(test_pipe_), SyscallSucceeds()); ASSERT_THAT(fcntl(test_pipe_[0], F_SETFL, O_NONBLOCK), SyscallSucceeds()); } @@ -96,18 +85,12 @@ class FileTest : public ::testing::Test { CloseFile(); UnlinkFile(); ClosePipes(); - - // FIXME(edahlgren): enable when mknod syscall is supported. - // close(test_fifo_[0]); - // close(test_fifo_[1]); - // unlink(test_fifo_name_.c_str()); } + protected: std::string test_file_name_; - std::string test_fifo_name_; FileDescriptor test_file_fd_; - int test_fifo_[2]; int test_pipe_[2]; }; diff --git a/test/syscalls/linux/fork.cc b/test/syscalls/linux/fork.cc index ff8bdfeb0..853f6231a 100644 --- a/test/syscalls/linux/fork.cc +++ b/test/syscalls/linux/fork.cc @@ -431,7 +431,6 @@ TEST(CloneTest, NewUserNamespacePermitsAllOtherNamespaces) { << "status = " << status; } -#ifdef __x86_64__ // Clone with CLONE_SETTLS and a non-canonical TLS address is rejected. TEST(CloneTest, NonCanonicalTLS) { constexpr uintptr_t kNonCanonical = 1ull << 48; @@ -440,11 +439,25 @@ TEST(CloneTest, NonCanonicalTLS) { // on this. char stack; + // The raw system call interface on x86-64 is: + // long clone(unsigned long flags, void *stack, + // int *parent_tid, int *child_tid, + // unsigned long tls); + // + // While on arm64, the order of the last two arguments is reversed: + // long clone(unsigned long flags, void *stack, + // int *parent_tid, unsigned long tls, + // int *child_tid); +#if defined(__x86_64__) EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr, nullptr, kNonCanonical), SyscallFailsWithErrno(EPERM)); -} +#elif defined(__aarch64__) + EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr, + kNonCanonical, nullptr), + SyscallFailsWithErrno(EPERM)); #endif +} } // namespace } // namespace testing diff --git a/test/syscalls/linux/getrandom.cc b/test/syscalls/linux/getrandom.cc index f97f60029..f87cdd7a1 100644 --- a/test/syscalls/linux/getrandom.cc +++ b/test/syscalls/linux/getrandom.cc @@ -29,6 +29,8 @@ namespace { #define SYS_getrandom 318 #elif defined(__i386__) #define SYS_getrandom 355 +#elif defined(__aarch64__) +#define SYS_getrandom 278 #else #error "Unknown architecture" #endif diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc index bba022a41..98d07ae85 100644 --- a/test/syscalls/linux/ip_socket_test_util.cc +++ b/test/syscalls/linux/ip_socket_test_util.cc @@ -16,7 +16,6 @@ #include <net/if.h> #include <netinet/in.h> -#include <sys/ioctl.h> #include <sys/socket.h> #include <cstring> @@ -35,12 +34,11 @@ uint16_t PortFromInetSockaddr(const struct sockaddr* addr) { } PosixErrorOr<int> InterfaceIndex(std::string name) { - // TODO(igudger): Consider using netlink. - ifreq req = {}; - memcpy(req.ifr_name, name.c_str(), name.size()); - ASSIGN_OR_RETURN_ERRNO(auto sock, Socket(AF_INET, SOCK_DGRAM, 0)); - RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(sock.get(), SIOCGIFINDEX, &req)); - return req.ifr_ifindex; + int index = if_nametoindex(name.c_str()); + if (index) { + return index; + } + return PosixError(errno); } namespace { @@ -177,17 +175,17 @@ SocketKind IPv6TCPUnboundSocket(int type) { PosixError IfAddrHelper::Load() { Release(); RETURN_ERROR_IF_SYSCALL_FAIL(getifaddrs(&ifaddr_)); - return PosixError(0); + return NoError(); } void IfAddrHelper::Release() { if (ifaddr_) { freeifaddrs(ifaddr_); + ifaddr_ = nullptr; } - ifaddr_ = nullptr; } -std::vector<std::string> IfAddrHelper::InterfaceList(int family) { +std::vector<std::string> IfAddrHelper::InterfaceList(int family) const { std::vector<std::string> names; for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) { if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) { @@ -198,7 +196,7 @@ std::vector<std::string> IfAddrHelper::InterfaceList(int family) { return names; } -sockaddr* IfAddrHelper::GetAddr(int family, std::string name) { +const sockaddr* IfAddrHelper::GetAddr(int family, std::string name) const { for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) { if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) { continue; @@ -210,7 +208,7 @@ sockaddr* IfAddrHelper::GetAddr(int family, std::string name) { return nullptr; } -PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) { +PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) const { return InterfaceIndex(name); } diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h index 39fd6709d..9c3859fcd 100644 --- a/test/syscalls/linux/ip_socket_test_util.h +++ b/test/syscalls/linux/ip_socket_test_util.h @@ -110,10 +110,10 @@ class IfAddrHelper { PosixError Load(); void Release(); - std::vector<std::string> InterfaceList(int family); + std::vector<std::string> InterfaceList(int family) const; - struct sockaddr* GetAddr(int family, std::string name); - PosixErrorOr<int> GetIndex(std::string name); + const sockaddr* GetAddr(int family, std::string name) const; + PosixErrorOr<int> GetIndex(std::string name) const; private: struct ifaddrs* ifaddr_; diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc index 8b48f0804..dd981a278 100644 --- a/test/syscalls/linux/itimer.cc +++ b/test/syscalls/linux/itimer.cc @@ -246,7 +246,7 @@ int TestSIGPROFFairness(absl::Duration sleep) { // The number of samples on the main thread should be very low as it did // nothing. - TEST_CHECK(result.main_thread_samples < 60); + TEST_CHECK(result.main_thread_samples < 80); // Both workers should get roughly equal number of samples. TEST_CHECK(result.worker_samples.size() == 2); diff --git a/test/syscalls/linux/lseek.cc b/test/syscalls/linux/lseek.cc index a8af8e545..6ce1e6cc3 100644 --- a/test/syscalls/linux/lseek.cc +++ b/test/syscalls/linux/lseek.cc @@ -53,7 +53,7 @@ TEST(LseekTest, NegativeOffset) { // A 32-bit off_t is not large enough to represent an offset larger than // maximum file size on standard file systems, so it isn't possible to cause // overflow. -#ifdef __x86_64__ +#if defined(__x86_64__) || defined(__aarch64__) TEST(LseekTest, Overflow) { // HA! Classic Linux. We really should have an EOVERFLOW // here, since we're seeking to something that cannot be diff --git a/test/syscalls/linux/memfd.cc b/test/syscalls/linux/memfd.cc index e57b49a4a..f8b7f7938 100644 --- a/test/syscalls/linux/memfd.cc +++ b/test/syscalls/linux/memfd.cc @@ -16,6 +16,7 @@ #include <fcntl.h> #include <linux/magic.h> #include <linux/memfd.h> +#include <linux/unistd.h> #include <string.h> #include <sys/mman.h> #include <sys/statfs.h> diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc index def4c50a4..4036a9275 100644 --- a/test/syscalls/linux/mkdir.cc +++ b/test/syscalls/linux/mkdir.cc @@ -36,21 +36,12 @@ class MkdirTest : public ::testing::Test { // TearDown unlinks created files. void TearDown() override { - // FIXME(edahlgren): We don't currently implement rmdir. - // We do this unconditionally because there's no harm in trying. - rmdir(dirname_.c_str()); + EXPECT_THAT(rmdir(dirname_.c_str()), SyscallSucceeds()); } std::string dirname_; }; -TEST_F(MkdirTest, DISABLED_CanCreateReadbleDir) { - ASSERT_THAT(mkdir(dirname_.c_str(), 0444), SyscallSucceeds()); - ASSERT_THAT( - open(JoinPath(dirname_, "anything").c_str(), O_RDWR | O_CREAT, 0666), - SyscallFailsWithErrno(EACCES)); -} - TEST_F(MkdirTest, CanCreateWritableDir) { ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds()); std::string filename = JoinPath(dirname_, "anything"); @@ -84,10 +75,11 @@ TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) { ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - auto parent = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0555)); - auto dir = JoinPath(parent.path(), "foo"); - ASSERT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(EACCES)); + ASSERT_THAT(mkdir(dirname_.c_str(), 0555), SyscallSucceeds()); + auto dir = JoinPath(dirname_.c_str(), "foo"); + EXPECT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(EACCES)); + EXPECT_THAT(open(JoinPath(dirname_, "file").c_str(), O_RDWR | O_CREAT, 0666), + SyscallFailsWithErrno(EACCES)); } } // namespace diff --git a/test/syscalls/linux/mlock.cc b/test/syscalls/linux/mlock.cc index 367a90fe1..78ac96bed 100644 --- a/test/syscalls/linux/mlock.cc +++ b/test/syscalls/linux/mlock.cc @@ -199,8 +199,10 @@ TEST(MunlockallTest, Basic) { } #ifndef SYS_mlock2 -#ifdef __x86_64__ +#if defined(__x86_64__) #define SYS_mlock2 325 +#elif defined(__aarch64__) +#define SYS_mlock2 284 #endif #endif diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc index 11fb1b457..6d3227ab6 100644 --- a/test/syscalls/linux/mmap.cc +++ b/test/syscalls/linux/mmap.cc @@ -361,7 +361,7 @@ TEST_F(MMapTest, MapFixed) { } // 64-bit addresses work too -#ifdef __x86_64__ +#if defined(__x86_64__) || defined(__aarch64__) TEST_F(MMapTest, MapFixed64) { EXPECT_THAT(Map(0x300000000000, kPageSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED, -1, 0), @@ -571,6 +571,12 @@ const uint8_t machine_code[] = { 0xb8, 0x2a, 0x00, 0x00, 0x00, // movl $42, %eax 0xc3, // retq }; +#elif defined(__aarch64__) +const uint8_t machine_code[] = { + 0x40, 0x05, 0x80, 0x52, // mov w0, #42 + 0xc0, 0x03, 0x5f, 0xd6, // ret +}; +#endif // PROT_EXEC allows code execution TEST_F(MMapTest, ProtExec) { @@ -605,7 +611,6 @@ TEST_F(MMapTest, NoProtExecDeath) { EXPECT_EXIT(func(), ::testing::KilledBySignal(SIGSEGV), ""); } -#endif TEST_F(MMapTest, NoExceedLimitData) { void* prevbrk; @@ -1644,6 +1649,7 @@ TEST(MMapNoFixtureTest, MapReadOnlyAfterCreateWriteOnly) { } // Conditional on MAP_32BIT. +// This flag is supported only on x86-64, for 64-bit programs. #ifdef __x86_64__ TEST(MMapNoFixtureTest, Map32Bit) { diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc index 267ae19f6..640fe6bfc 100644 --- a/test/syscalls/linux/open.cc +++ b/test/syscalls/linux/open.cc @@ -186,6 +186,28 @@ TEST_F(OpenTest, OpenNoFollowStillFollowsLinksInPath) { ASSERT_NO_ERRNO_AND_VALUE(Open(path_via_symlink, O_RDONLY | O_NOFOLLOW)); } +// Test that open(2) can follow symlinks that point back to the same tree. +// Test sets up files as follows: +// root/child/symlink => redirects to ../.. +// root/child/target => regular file +// +// open("root/child/symlink/root/child/file") +TEST_F(OpenTest, SymlinkRecurse) { + auto root = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(GetAbsoluteTestTmpdir())); + auto child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path())); + auto symlink = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(child.path(), "../..")); + auto target = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateFileWith(child.path(), "abc", 0644)); + auto path_via_symlink = + JoinPath(symlink.path(), Basename(root.path()), Basename(child.path()), + Basename(target.path())); + const auto contents = + ASSERT_NO_ERRNO_AND_VALUE(GetContents(path_via_symlink)); + ASSERT_EQ(contents, "abc"); +} + TEST_F(OpenTest, Fault) { char* totally_not_null = nullptr; ASSERT_THAT(open(totally_not_null, O_RDONLY), SyscallFailsWithErrno(EFAULT)); diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index d8e19e910..67228b66b 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -265,6 +265,8 @@ TEST_P(PipeTest, OffsetCalls) { SyscallFailsWithErrno(ESPIPE)); struct iovec iov; + iov.iov_base = &buf; + iov.iov_len = sizeof(buf); EXPECT_THAT(preadv(wfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE)); EXPECT_THAT(pwritev(rfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE)); } diff --git a/test/syscalls/linux/poll.cc b/test/syscalls/linux/poll.cc index c42472474..1e35a4a8b 100644 --- a/test/syscalls/linux/poll.cc +++ b/test/syscalls/linux/poll.cc @@ -266,7 +266,7 @@ TEST_F(PollTest, Nfds) { } rlim_t max_fds = rlim.rlim_cur; - std::cout << "Using limit: " << max_fds; + std::cout << "Using limit: " << max_fds << std::endl; // Create an eventfd. Since its value is initially zero, it is writable. FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()); diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc index 2cecf2e5f..bcdbbb044 100644 --- a/test/syscalls/linux/pread64.cc +++ b/test/syscalls/linux/pread64.cc @@ -14,6 +14,7 @@ #include <errno.h> #include <fcntl.h> +#include <linux/unistd.h> #include <sys/mman.h> #include <sys/socket.h> #include <sys/types.h> @@ -118,6 +119,21 @@ TEST_F(Pread64Test, EndOfFile) { EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallSucceedsWithValue(0)); } +int memfd_create(const std::string& name, unsigned int flags) { + return syscall(__NR_memfd_create, name.c_str(), flags); +} + +TEST_F(Pread64Test, Overflow) { + int f = memfd_create("negative", 0); + const FileDescriptor fd(f); + + EXPECT_THAT(ftruncate(fd.get(), 0x7fffffffffffffffull), SyscallSucceeds()); + + char buf[10]; + EXPECT_THAT(pread64(fd.get(), buf, sizeof(buf), 0x7fffffffffffffffull), + SyscallFailsWithErrno(EINVAL)); +} + TEST(Pread64TestNoTempFile, CantReadSocketPair_NoRandomSave) { int sock_fds[2]; EXPECT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_fds), SyscallSucceeds()); diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index 5a70f6c3b..79a625ebc 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -994,7 +994,7 @@ constexpr uint64_t kMappingSize = 100 << 20; // Tolerance on RSS comparisons to account for background thread mappings, // reclaimed pages, newly faulted pages, etc. -constexpr uint64_t kRSSTolerance = 5 << 20; +constexpr uint64_t kRSSTolerance = 10 << 20; // Capture RSS before and after an anonymous mapping with passed prot. void MapPopulateRSS(int prot, uint64_t* before, uint64_t* after) { @@ -1326,8 +1326,6 @@ TEST(ProcPidSymlink, SubprocessRunning) { SyscallSucceedsWithValue(sizeof(buf))); } -// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux -// on proc files. TEST(ProcPidSymlink, SubprocessZombied) { ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); @@ -1337,7 +1335,7 @@ TEST(ProcPidSymlink, SubprocessZombied) { int want = EACCES; if (!IsRunningOnGvisor()) { auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion()); - if (version.major == 4 && version.minor > 3) { + if (version.major > 4 || (version.major == 4 && version.minor > 3)) { want = ENOENT; } } @@ -1350,30 +1348,25 @@ TEST(ProcPidSymlink, SubprocessZombied) { SyscallFailsWithErrno(want)); } - // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux - // on proc files. + // FIXME(gvisor.dev/issue/164): Inconsistent behavior between linux on proc + // files. // // ~4.3: Syscall fails with EACCES. - // 4.17 & gVisor: Syscall succeeds and returns 1. + // 4.17: Syscall succeeds and returns 1. // - // EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)), - // SyscallFailsWithErrno(EACCES)); + if (!IsRunningOnGvisor()) { + return; + } - // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux - // on proc files. - // - // ~4.3: Syscall fails with EACCES. - // 4.17 & gVisor: Syscall succeeds and returns 1. - // - // EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)), - // SyscallFailsWithErrno(EACCES)); + EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)), + SyscallFailsWithErrno(want)); + + EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)), + SyscallFailsWithErrno(want)); } // Test whether /proc/PID/ symlinks can be read for an exited process. TEST(ProcPidSymlink, SubprocessExited) { - // FIXME(gvisor.dev/issue/164): These all succeed on gVisor. - SKIP_IF(IsRunningOnGvisor()); - char buf[1]; EXPECT_THAT(ReadlinkWhileExited("exe", buf, sizeof(buf)), diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index 4e23d1e78..cac394910 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -353,7 +353,7 @@ TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) { EXPECT_EQ(oldNoPorts, newNoPorts - 1); } -TEST(ProcNetSnmp, UdpIn) { +TEST(ProcNetSnmp, UdpIn_NoRandomSave) { // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. const DisableSave ds; diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc index 66db0acaa..a63067586 100644 --- a/test/syscalls/linux/proc_net_unix.cc +++ b/test/syscalls/linux/proc_net_unix.cc @@ -106,7 +106,7 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() { std::vector<UnixEntry> entries; std::vector<std::string> lines = absl::StrSplit(content, '\n'); std::cerr << "<contents of /proc/net/unix>" << std::endl; - for (std::string line : lines) { + for (const std::string& line : lines) { // Emit the proc entry to the test output to provide context for the test // results. std::cerr << line << std::endl; @@ -374,7 +374,7 @@ TEST(ProcNetUnix, DgramSocketStateDisconnectingOnBind) { // corresponding entries, as they don't have an address yet. if (IsRunningOnGvisor()) { ASSERT_EQ(entries.size(), 2); - for (auto e : entries) { + for (const auto& e : entries) { ASSERT_EQ(e.state, SS_DISCONNECTING); } } @@ -403,7 +403,7 @@ TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) { // corresponding entries, as they don't have an address yet. if (IsRunningOnGvisor()) { ASSERT_EQ(entries.size(), 2); - for (auto e : entries) { + for (const auto& e : entries) { ASSERT_EQ(e.state, SS_DISCONNECTING); } } diff --git a/test/syscalls/linux/proc_pid_smaps.cc b/test/syscalls/linux/proc_pid_smaps.cc index 7f2e8f203..9fb1b3a2c 100644 --- a/test/syscalls/linux/proc_pid_smaps.cc +++ b/test/syscalls/linux/proc_pid_smaps.cc @@ -173,7 +173,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps( return; } unknown_fields.insert(std::string(key)); - std::cerr << "skipping unknown smaps field " << key; + std::cerr << "skipping unknown smaps field " << key << std::endl; }; auto lines = absl::StrSplit(contents, '\n', absl::SkipEmpty()); @@ -191,7 +191,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps( // amount of whitespace). if (!entry) { std::cerr << "smaps line not considered a maps line: " - << maybe_maps_entry.error_message(); + << maybe_maps_entry.error_message() << std::endl; return PosixError( EINVAL, absl::StrCat("smaps field line without preceding maps line: ", l)); diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc index bfe3e2603..926690eb8 100644 --- a/test/syscalls/linux/ptrace.cc +++ b/test/syscalls/linux/ptrace.cc @@ -400,9 +400,11 @@ TEST(PtraceTest, GetRegSet) { // Read exactly the full register set. EXPECT_EQ(iov.iov_len, sizeof(regs)); -#ifdef __x86_64__ +#if defined(__x86_64__) // Child called kill(2), with SIGSTOP as arg 2. EXPECT_EQ(regs.rsi, SIGSTOP); +#elif defined(__aarch64__) + EXPECT_EQ(regs.regs[1], SIGSTOP); #endif // Suppress SIGSTOP and resume the child. @@ -752,15 +754,23 @@ TEST(PtraceTest, SyscallSucceeds()); EXPECT_TRUE(siginfo.si_code == SIGTRAP || siginfo.si_code == (SIGTRAP | 0x80)) << "si_code = " << siginfo.si_code; -#ifdef __x86_64__ + { struct user_regs_struct regs = {}; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov), + SyscallSucceeds()); +#if defined(__x86_64__) EXPECT_TRUE(regs.orig_rax == SYS_vfork || regs.orig_rax == SYS_clone) << "orig_rax = " << regs.orig_rax; EXPECT_EQ(grandchild_pid, regs.rax); - } +#elif defined(__aarch64__) + EXPECT_TRUE(regs.regs[8] == SYS_clone) << "regs[8] = " << regs.regs[8]; + EXPECT_EQ(grandchild_pid, regs.regs[0]); #endif // defined(__x86_64__) + } // After this point, the child will be making wait4 syscalls that will be // interrupted by saving, so saving is not permitted. Note that this is @@ -805,14 +815,21 @@ TEST(PtraceTest, SyscallSucceedsWithValue(child_pid)); EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80)) << " status " << status; -#ifdef __x86_64__ { struct user_regs_struct regs = {}; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov), + SyscallSucceeds()); +#if defined(__x86_64__) EXPECT_EQ(SYS_wait4, regs.orig_rax); EXPECT_EQ(grandchild_pid, regs.rax); - } +#elif defined(__aarch64__) + EXPECT_EQ(SYS_wait4, regs.regs[8]); + EXPECT_EQ(grandchild_pid, regs.regs[0]); #endif // defined(__x86_64__) + } // Detach from the child and wait for it to exit. ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds()); @@ -1188,7 +1205,7 @@ TEST(PtraceTest, SeizeSetOptions) { // gVisor is not susceptible to this race because // kernel.Task.waitCollectTraceeStopLocked() checks specifically for an // active ptraceStop, which is not initiated if SIGKILL is pending. - std::cout << "Observed syscall-exit after SIGKILL"; + std::cout << "Observed syscall-exit after SIGKILL" << std::endl; ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceedsWithValue(child_pid)); } diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index dafe64d20..b8a0159ba 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -1126,7 +1126,7 @@ TEST_F(PtyTest, SwitchTwiceMultiline) { std::string kExpected = "GO\nBLUE\n!"; // Write each line. - for (std::string input : kInputs) { + for (const std::string& input : kInputs) { ASSERT_THAT(WriteFd(master_.get(), input.c_str(), input.size()), SyscallSucceedsWithValue(input.size())); } diff --git a/test/syscalls/linux/pwrite64.cc b/test/syscalls/linux/pwrite64.cc index b48fe540d..e69794910 100644 --- a/test/syscalls/linux/pwrite64.cc +++ b/test/syscalls/linux/pwrite64.cc @@ -14,6 +14,7 @@ #include <errno.h> #include <fcntl.h> +#include <linux/unistd.h> #include <sys/socket.h> #include <sys/types.h> #include <unistd.h> @@ -27,14 +28,7 @@ namespace testing { namespace { -// This test is currently very rudimentary. -// -// TODO(edahlgren): -// * bad buffer states (EFAULT). -// * bad fds (wrong permission, wrong type of file, EBADF). -// * check offset is not incremented. -// * check for EOF. -// * writing to pipes, symlinks, special files. +// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary. class Pwrite64 : public ::testing::Test { void SetUp() override { name_ = NewTempAbsPath(); @@ -72,6 +66,17 @@ TEST_F(Pwrite64, InvalidArgs) { EXPECT_THAT(close(fd), SyscallSucceeds()); } +TEST_F(Pwrite64, Overflow) { + int fd; + ASSERT_THAT(fd = open(name_.c_str(), O_APPEND | O_RDWR), SyscallSucceeds()); + constexpr int64_t kBufSize = 1024; + std::vector<char> buf(kBufSize); + std::fill(buf.begin(), buf.end(), 'a'); + EXPECT_THAT(PwriteFd(fd, buf.data(), buf.size(), 0x7fffffffffffffffull), + SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(close(fd), SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/rseq/BUILD b/test/syscalls/linux/rseq/BUILD index ee5b0a11b..853258b04 100644 --- a/test/syscalls/linux/rseq/BUILD +++ b/test/syscalls/linux/rseq/BUILD @@ -21,23 +21,23 @@ genrule( ], outs = ["rseq"], cmd = "$(CC) " + - "$(CC_FLAGS) " + - "-I. " + - "-Wall " + - "-Werror " + - "-O2 " + - "-std=c++17 " + - "-static " + - "-nostdlib " + - "-ffreestanding " + - "-o " + - "$(location rseq) " + - select_arch( - amd64 = "$(location critical_amd64.S) $(location start_amd64.S) ", - arm64 = "$(location critical_arm64.S) $(location start_arm64.S) ", - no_match_error = "unsupported architecture", - ) + - "$(location rseq.cc)", + "$(CC_FLAGS) " + + "-I. " + + "-Wall " + + "-Werror " + + "-O2 " + + "-std=c++17 " + + "-static " + + "-nostdlib " + + "-ffreestanding " + + "-o " + + "$(location rseq) " + + select_arch( + amd64 = "$(location critical_amd64.S) $(location start_amd64.S) ", + arm64 = "$(location critical_arm64.S) $(location start_arm64.S) ", + no_match_error = "unsupported architecture", + ) + + "$(location rseq.cc)", toolchains = [ cc_toolchain, ":no_pie_cc_flags", diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index 580ab5193..64123e904 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <fcntl.h> +#include <linux/unistd.h> #include <sys/eventfd.h> #include <sys/sendfile.h> #include <unistd.h> @@ -70,6 +71,28 @@ TEST(SendFileTest, InvalidOffset) { SyscallFailsWithErrno(EINVAL)); } +int memfd_create(const std::string& name, unsigned int flags) { + return syscall(__NR_memfd_create, name.c_str(), flags); +} + +TEST(SendFileTest, Overflow) { + // Create input file. + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Open the output file. + int fd; + EXPECT_THAT(fd = memfd_create("overflow", 0), SyscallSucceeds()); + const FileDescriptor outf(fd); + + // out_offset + kSize overflows INT64_MAX. + loff_t out_offset = 0x7ffffffffffffffeull; + constexpr int kSize = 3; + EXPECT_THAT(sendfile(outf.get(), inf.get(), &out_offset, kSize), + SyscallFailsWithErrno(EINVAL)); +} + TEST(SendFileTest, SendTrivially) { // Create temp files. constexpr char kData[] = "To be, or not to be, that is the question:"; @@ -530,6 +553,34 @@ TEST(SendFileTest, SendToSpecialFile) { SyscallSucceedsWithValue(kSize & (~7))); } +TEST(SendFileTest, SendFileToPipe) { + // Create temp file. + constexpr char kData[] = "<insert-quote-here>"; + constexpr int kDataSize = sizeof(kData) - 1; + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Create a pipe for sending to a pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Expect to read up to the given size. + std::vector<char> buf(kDataSize); + ScopedThread t([&]() { + absl::SleepFor(absl::Milliseconds(100)); + ASSERT_THAT(read(rfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(kDataSize)); + }); + + // Send with twice the size of the file, which should hit EOF. + EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize * 2), + SyscallSucceedsWithValue(kDataSize)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/sendfile_socket.cc b/test/syscalls/linux/sendfile_socket.cc index 8f7ee4163..c101fe9d2 100644 --- a/test/syscalls/linux/sendfile_socket.cc +++ b/test/syscalls/linux/sendfile_socket.cc @@ -23,6 +23,7 @@ #include "gtest/gtest.h" #include "absl/strings/string_view.h" +#include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" @@ -35,61 +36,39 @@ namespace { class SendFileTest : public ::testing::TestWithParam<int> { protected: - PosixErrorOr<std::tuple<int, int>> Sockets() { + PosixErrorOr<std::unique_ptr<SocketPair>> Sockets(int type) { // Bind a server socket. int family = GetParam(); - struct sockaddr server_addr = {}; switch (family) { case AF_INET: { - struct sockaddr_in* server_addr_in = - reinterpret_cast<struct sockaddr_in*>(&server_addr); - server_addr_in->sin_family = family; - server_addr_in->sin_addr.s_addr = INADDR_ANY; - break; + if (type == SOCK_STREAM) { + return SocketPairKind{ + "TCP", AF_INET, type, 0, + TCPAcceptBindSocketPairCreator(AF_INET, type, 0, false)} + .Create(); + } else { + return SocketPairKind{ + "UDP", AF_INET, type, 0, + UDPBidirectionalBindSocketPairCreator(AF_INET, type, 0, false)} + .Create(); + } } case AF_UNIX: { - struct sockaddr_un* server_addr_un = - reinterpret_cast<struct sockaddr_un*>(&server_addr); - server_addr_un->sun_family = family; - server_addr_un->sun_path[0] = '\0'; - break; + if (type == SOCK_STREAM) { + return SocketPairKind{ + "UNIX", AF_UNIX, type, 0, + FilesystemAcceptBindSocketPairCreator(AF_UNIX, type, 0)} + .Create(); + } else { + return SocketPairKind{ + "UNIX", AF_UNIX, type, 0, + FilesystemBidirectionalBindSocketPairCreator(AF_UNIX, type, 0)} + .Create(); + } } default: return PosixError(EINVAL); } - int server = socket(family, SOCK_STREAM, 0); - if (bind(server, &server_addr, sizeof(server_addr)) < 0) { - return PosixError(errno); - } - if (listen(server, 1) < 0) { - close(server); - return PosixError(errno); - } - - // Fetch the address; both are anonymous. - socklen_t length = sizeof(server_addr); - if (getsockname(server, &server_addr, &length) < 0) { - close(server); - return PosixError(errno); - } - - // Connect the client. - int client = socket(family, SOCK_STREAM, 0); - if (connect(client, &server_addr, length) < 0) { - close(server); - close(client); - return PosixError(errno); - } - - // Accept on the server. - int server_client = accept(server, nullptr, 0); - if (server_client < 0) { - close(server); - close(client); - return PosixError(errno); - } - close(server); - return std::make_tuple(client, server_client); } }; @@ -106,9 +85,7 @@ TEST_P(SendFileTest, SendMultiple) { const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); // Create sockets. - std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets()); - const FileDescriptor server(std::get<0>(fds)); - FileDescriptor client(std::get<1>(fds)); // non-const, reset is used. + auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM)); // Thread that reads data from socket and dumps to a file. ScopedThread th([&] { @@ -118,7 +95,7 @@ TEST_P(SendFileTest, SendMultiple) { // Read until socket is closed. char buf[10240]; for (int cnt = 0;; cnt++) { - int r = RetryEINTR(read)(server.get(), buf, sizeof(buf)); + int r = RetryEINTR(read)(socks->first_fd(), buf, sizeof(buf)); // We cannot afford to save on every read() call. if (cnt % 1000 == 0) { ASSERT_THAT(r, SyscallSucceeds()); @@ -149,10 +126,10 @@ TEST_P(SendFileTest, SendMultiple) { for (size_t sent = 0; sent < data.size(); cnt++) { const size_t remain = data.size() - sent; std::cout << "sendfile, size=" << data.size() << ", sent=" << sent - << ", remain=" << remain; + << ", remain=" << remain << std::endl; // Send data and verify that sendfile returns the correct value. - int res = sendfile(client.get(), inf.get(), nullptr, remain); + int res = sendfile(socks->second_fd(), inf.get(), nullptr, remain); // We cannot afford to save on every sendfile() call. if (cnt % 120 == 0) { MaybeSave(); @@ -169,7 +146,7 @@ TEST_P(SendFileTest, SendMultiple) { } // Close socket to stop thread. - client.reset(); + close(socks->release_second_fd()); th.Join(); // Verify that the output file has the correct data. @@ -183,9 +160,7 @@ TEST_P(SendFileTest, SendMultiple) { TEST_P(SendFileTest, Shutdown) { // Create a socket. - std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets()); - const FileDescriptor client(std::get<0>(fds)); - FileDescriptor server(std::get<1>(fds)); // non-const, reset below. + auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM)); // If this is a TCP socket, then turn off linger. if (GetParam() == AF_INET) { @@ -193,7 +168,7 @@ TEST_P(SendFileTest, Shutdown) { sl.l_onoff = 1; sl.l_linger = 0; ASSERT_THAT( - setsockopt(server.get(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + setsockopt(socks->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), SyscallSucceeds()); } @@ -212,12 +187,12 @@ TEST_P(SendFileTest, Shutdown) { ScopedThread t([&]() { size_t done = 0; while (done < data.size()) { - int n = RetryEINTR(read)(server.get(), data.data(), data.size()); + int n = RetryEINTR(read)(socks->first_fd(), data.data(), data.size()); ASSERT_THAT(n, SyscallSucceeds()); done += n; } // Close the server side socket. - server.reset(); + close(socks->release_first_fd()); }); // Continuously stream from the file to the socket. Note we do not assert @@ -225,7 +200,7 @@ TEST_P(SendFileTest, Shutdown) { // data is written. Eventually, we should get a connection reset error. while (1) { off_t offset = 0; // Always read from the start. - int n = sendfile(client.get(), inf.get(), &offset, data.size()); + int n = sendfile(socks->second_fd(), inf.get(), &offset, data.size()); EXPECT_THAT(n, AnyOf(SyscallFailsWithErrno(ECONNRESET), SyscallFailsWithErrno(EPIPE), SyscallSucceeds())); if (n <= 0) { @@ -234,6 +209,20 @@ TEST_P(SendFileTest, Shutdown) { } } +TEST_P(SendFileTest, SendpageFromEmptyFileToUDP) { + auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_DGRAM)); + + TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + + // The value to the count argument has to be so that it is impossible to + // allocate a buffer of this size. In Linux, sendfile transfer at most + // 0x7ffff000 (MAX_RW_COUNT) bytes. + EXPECT_THAT(sendfile(socks->first_fd(), fd.get(), 0x0, 0x8000000000004), + SyscallSucceedsWithValue(0)); +} + INSTANTIATE_TEST_SUITE_P(AddressFamily, SendFileTest, ::testing::Values(AF_UNIX, AF_INET)); diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index b24618a88..d3000dbc6 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -234,7 +234,7 @@ TEST_P(DualStackSocketTest, AddressOperations) { } } -// TODO(gvisor.dev/issues/1556): uncomment V4MappedAny. +// TODO(gvisor.dev/issue/1556): uncomment V4MappedAny. INSTANTIATE_TEST_SUITE_P( All, DualStackSocketTest, ::testing::Combine( @@ -319,17 +319,14 @@ TEST_P(SocketInetLoopbackTest, TCPListenUnbound) { tcpSimpleConnectTest(listener, connector, false); } -TEST_P(SocketInetLoopbackTest, TCPListenClose) { +TEST_P(SocketInetLoopbackTest, TCPListenShutdown) { auto const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; - constexpr int kAcceptCount = 32; - constexpr int kBacklog = kAcceptCount * 2; - constexpr int kFDs = 128; - constexpr int kThreadCount = 4; - constexpr int kFDsPerThread = kFDs / kThreadCount; + constexpr int kBacklog = 2; + constexpr int kFDs = kBacklog + 1; // Create the listening socket. FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( @@ -348,39 +345,167 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - DisableSave ds; // Too many system calls. sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - FileDescriptor clients[kFDs]; - std::unique_ptr<ScopedThread> threads[kThreadCount]; + + // Shutdown the write of the listener, expect to not have any effect. + ASSERT_THAT(shutdown(listen_fd.get(), SHUT_WR), SyscallSucceeds()); + for (int i = 0; i < kFDs; i++) { - clients[i] = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds()); } - for (int i = 0; i < kThreadCount; i++) { - threads[i] = absl::make_unique<ScopedThread>([&connector, &conn_addr, - &clients, i]() { - for (int j = 0; j < kFDsPerThread; j++) { - int k = i * kFDsPerThread + j; - int ret = - connect(clients[k].get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len); - if (ret != 0) { - EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); - } - } - }); + + // Shutdown the read of the listener, expect to fail subsequent + // server accepts, binds and client connects. + ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds()); + + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), + SyscallFailsWithErrno(EINVAL)); + + // Check that shutdown did not release the port. + FileDescriptor new_listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT( + bind(new_listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); + + // Check that subsequent connection attempts receive a RST. + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + for (int i = 0; i < kFDs; i++) { + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallFailsWithErrno(ECONNREFUSED)); } - for (int i = 0; i < kThreadCount; i++) { - threads[i]->Join(); +} + +TEST_P(SocketInetLoopbackTest, TCPListenClose) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + constexpr int kAcceptCount = 2; + constexpr int kBacklog = kAcceptCount + 2; + constexpr int kFDs = kBacklog * 3; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + std::vector<FileDescriptor> clients; + for (int i = 0; i < kFDs; i++) { + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + } + clients.push_back(std::move(client)); } for (int i = 0; i < kAcceptCount; i++) { auto accepted = ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); } - // TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked - // before function end. - // ds.reset(); +} + +void TestListenWhileConnect(const TestParam& param, + void (*stopListen)(FileDescriptor&)) { + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + constexpr int kBacklog = 2; + constexpr int kClients = kBacklog + 1; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + std::vector<FileDescriptor> clients; + for (int i = 0; i < kClients; i++) { + FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + clients.push_back(std::move(client)); + } + } + + stopListen(listen_fd); + + for (auto& client : clients) { + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = client.get(), + .events = POLLIN, + }; + // When the listening socket is closed, then we expect the remote to reset + // the connection. + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR); + char c; + // Subsequent read can fail with: + // ECONNRESET: If the client connection was established and was reset by the + // remote. + // ECONNREFUSED: If the client connection failed to be established. + ASSERT_THAT(read(client.get(), &c, sizeof(c)), + AnyOf(SyscallFailsWithErrno(ECONNRESET), + SyscallFailsWithErrno(ECONNREFUSED))); + } +} + +TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) { + TestListenWhileConnect(GetParam(), [](FileDescriptor& f) { + ASSERT_THAT(close(f.release()), SyscallSucceeds()); + }); +} + +TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) { + TestListenWhileConnect(GetParam(), [](FileDescriptor& f) { + ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds()); + }); } TEST_P(SocketInetLoopbackTest, TCPbacklog) { @@ -605,15 +730,23 @@ TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) { &conn_addrlen), SyscallSucceeds()); - constexpr int kTCPLingerTimeout = 5; - EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, - &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), - SyscallSucceedsWithValue(0)); + // Disable cooperative saves after this point as TCP timers are not restored + // across a S/R. + { + DisableSave ds; + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); - // close the connecting FD to trigger FIN_WAIT2 on the connected fd. - conn_fd.reset(); + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); - absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + // ds going out of scope will Re-enable S/R's since at this point the timer + // must have fired and cleaned up the endpoint. + } // Now bind and connect a new socket and verify that we can immediately // rebind the address bound by the conn_fd as it never entered TIME_WAIT. @@ -1082,6 +1215,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { if (connects_received >= kConnectAttempts) { // Another thread have shutdown our read side causing the // accept to fail. + ASSERT_EQ(errno, EINVAL); break; } ASSERT_NO_ERRNO(fd); @@ -1149,7 +1283,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); } -TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -1262,7 +1396,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); } -TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort) { +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -2138,8 +2272,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) { &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(connect(connected_fd.get(), - reinterpret_cast<sockaddr*>(&bound_addr), bound_addr_len), + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), + reinterpret_cast<sockaddr*>(&bound_addr), + bound_addr_len), SyscallSucceeds()); // Get the ephemeral port. @@ -2204,7 +2339,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { setsockopt(fd2, SOL_SOCKET, SO_REUSEPORT, &portreuse2, sizeof(int)), SyscallSucceeds()); - std::cout << portreuse1 << " " << portreuse2; + std::cout << portreuse1 << " " << portreuse2 << std::endl; int ret = bind(fd2, reinterpret_cast<sockaddr*>(&addr), addrlen); // Verify that two sockets can be bound to the same port only if diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc index 40e673625..d690d9564 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc @@ -45,37 +45,31 @@ void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() { got_if_infos_ = false; // Get interface list. - std::vector<std::string> if_names; ASSERT_NO_ERRNO(if_helper_.Load()); - if_names = if_helper_.InterfaceList(AF_INET); + std::vector<std::string> if_names = if_helper_.InterfaceList(AF_INET); if (if_names.size() != 2) { return; } // Figure out which interface is where. - int lo = 0, eth = 1; - if (if_names[lo] != "lo") { - lo = 1; - eth = 0; - } - - if (if_names[lo] != "lo") { - return; - } - - lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[lo])); - lo_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[lo]); - if (lo_if_addr_ == nullptr) { + std::string lo = if_names[0]; + std::string eth = if_names[1]; + if (lo != "lo") std::swap(lo, eth); + if (lo != "lo") return; + + lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(lo)); + auto lo_if_addr = if_helper_.GetAddr(AF_INET, lo); + if (lo_if_addr == nullptr) { return; } - lo_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(lo_if_addr_)->sin_addr; + lo_if_addr_ = *reinterpret_cast<const sockaddr_in*>(lo_if_addr); - eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[eth])); - eth_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[eth]); - if (eth_if_addr_ == nullptr) { + eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(eth)); + auto eth_if_addr = if_helper_.GetAddr(AF_INET, eth); + if (eth_if_addr == nullptr) { return; } - eth_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(eth_if_addr_)->sin_addr; + eth_if_addr_ = *reinterpret_cast<const sockaddr_in*>(eth_if_addr); got_if_infos_ = true; } @@ -242,7 +236,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Bind the non-receiving socket to the unicast ethernet address. auto norecv_addr = rcv1_addr; reinterpret_cast<sockaddr_in*>(&norecv_addr.addr)->sin_addr = - eth_if_sin_addr_; + eth_if_addr_.sin_addr; ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr), norecv_addr.addr_len), SyscallSucceedsWithValue(0)); @@ -1028,7 +1022,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn iface = {}; iface.imr_ifindex = lo_if_idx_; - iface.imr_address = eth_if_sin_addr_; + iface.imr_address = eth_if_addr_.sin_addr; ASSERT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, sizeof(iface)), SyscallSucceeds()); @@ -1058,7 +1052,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SKIP_IF(IsRunningOnGvisor()); // Verify the received source address. - EXPECT_EQ(eth_if_sin_addr_.s_addr, src_addr_in->sin_addr.s_addr); + EXPECT_EQ(eth_if_addr_.sin_addr.s_addr, src_addr_in->sin_addr.s_addr); } // Check that when we are bound to one interface we can set IP_MULTICAST_IF to @@ -1075,7 +1069,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Create sender and bind to eth interface. auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - ASSERT_THAT(bind(sender->get(), eth_if_addr_, sizeof(sockaddr_in)), + ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(ð_if_addr_), + sizeof(eth_if_addr_)), SyscallSucceeds()); // Run through all possible combinations of index and address for @@ -1085,9 +1080,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, struct in_addr imr_address; } test_data[] = { {lo_if_idx_, {}}, - {0, lo_if_sin_addr_}, - {lo_if_idx_, lo_if_sin_addr_}, - {lo_if_idx_, eth_if_sin_addr_}, + {0, lo_if_addr_.sin_addr}, + {lo_if_idx_, lo_if_addr_.sin_addr}, + {lo_if_idx_, eth_if_addr_.sin_addr}, }; for (auto t : test_data) { ip_mreqn iface = {}; diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h index bec2e96ee..10b90b1e0 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h @@ -36,10 +36,8 @@ class IPv4UDPUnboundExternalNetworkingSocketTest : public SimpleSocketTest { // Interface infos. int lo_if_idx_; int eth_if_idx_; - sockaddr* lo_if_addr_; - sockaddr* eth_if_addr_; - in_addr lo_if_sin_addr_; - in_addr eth_if_sin_addr_; + sockaddr_in lo_if_addr_; + sockaddr_in eth_if_addr_; }; } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index e5aed1eec..fbe61c5a0 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -26,7 +26,7 @@ #include "gtest/gtest.h" #include "absl/strings/str_format.h" -#include "absl/types/optional.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" #include "test/syscalls/linux/socket_netlink_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/capability_util.h" @@ -118,24 +118,6 @@ void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) { // TODO(mpratt): Check ifinfomsg contents and following attrs. } -PosixError DumpLinks( - const FileDescriptor& fd, uint32_t seq, - const std::function<void(const struct nlmsghdr* hdr)>& fn) { - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - }; - - struct request req = {}; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.hdr.nlmsg_seq = seq; - req.ifm.ifi_family = AF_UNSPEC; - - return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false); -} - TEST(NetlinkRouteTest, GetLinkDump) { FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); @@ -152,7 +134,7 @@ TEST(NetlinkRouteTest, GetLinkDump) { const struct ifinfomsg* msg = reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr)); std::cout << "Found interface idx=" << msg->ifi_index - << ", type=" << std::hex << msg->ifi_type; + << ", type=" << std::hex << msg->ifi_type << std::endl; if (msg->ifi_type == ARPHRD_LOOPBACK) { loopbackFound = true; EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0); @@ -161,37 +143,6 @@ TEST(NetlinkRouteTest, GetLinkDump) { EXPECT_TRUE(loopbackFound); } -struct Link { - int index; - std::string name; -}; - -PosixErrorOr<absl::optional<Link>> FindLoopbackLink() { - ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); - - absl::optional<Link> link; - RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { - if (hdr->nlmsg_type != RTM_NEWLINK || - hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) { - return; - } - const struct ifinfomsg* msg = - reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr)); - if (msg->ifi_type == ARPHRD_LOOPBACK) { - const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); - if (rta == nullptr) { - // Ignore links that do not have a name. - return; - } - - link = Link(); - link->index = msg->ifi_index; - link->name = std::string(reinterpret_cast<const char*>(RTA_DATA(rta))); - } - })); - return link; -} - // CheckLinkMsg checks a netlink message against an expected link. void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) { ASSERT_THAT(hdr->nlmsg_type, Eq(RTM_NEWLINK)); @@ -209,9 +160,7 @@ void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) { } TEST(NetlinkRouteTest, GetLinkByIndex) { - absl::optional<Link> loopback_link = - ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); - ASSERT_TRUE(loopback_link.has_value()); + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); @@ -227,13 +176,13 @@ TEST(NetlinkRouteTest, GetLinkByIndex) { req.hdr.nlmsg_flags = NLM_F_REQUEST; req.hdr.nlmsg_seq = kSeq; req.ifm.ifi_family = AF_UNSPEC; - req.ifm.ifi_index = loopback_link->index; + req.ifm.ifi_index = loopback_link.index; bool found = false; ASSERT_NO_ERRNO(NetlinkRequestResponse( fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { - CheckLinkMsg(hdr, *loopback_link); + CheckLinkMsg(hdr, loopback_link); found = true; }, false)); @@ -241,9 +190,7 @@ TEST(NetlinkRouteTest, GetLinkByIndex) { } TEST(NetlinkRouteTest, GetLinkByName) { - absl::optional<Link> loopback_link = - ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); - ASSERT_TRUE(loopback_link.has_value()); + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); @@ -262,8 +209,8 @@ TEST(NetlinkRouteTest, GetLinkByName) { req.hdr.nlmsg_seq = kSeq; req.ifm.ifi_family = AF_UNSPEC; req.rtattr.rta_type = IFLA_IFNAME; - req.rtattr.rta_len = RTA_LENGTH(loopback_link->name.size() + 1); - strncpy(req.ifname, loopback_link->name.c_str(), sizeof(req.ifname)); + req.rtattr.rta_len = RTA_LENGTH(loopback_link.name.size() + 1); + strncpy(req.ifname, loopback_link.name.c_str(), sizeof(req.ifname)); req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); @@ -271,7 +218,7 @@ TEST(NetlinkRouteTest, GetLinkByName) { ASSERT_NO_ERRNO(NetlinkRequestResponse( fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { - CheckLinkMsg(hdr, *loopback_link); + CheckLinkMsg(hdr, loopback_link); found = true; }, false)); @@ -523,9 +470,7 @@ TEST(NetlinkRouteTest, LookupAll) { TEST(NetlinkRouteTest, AddAddr) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - absl::optional<Link> loopback_link = - ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); - ASSERT_TRUE(loopback_link.has_value()); + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); @@ -545,7 +490,7 @@ TEST(NetlinkRouteTest, AddAddr) { req.ifa.ifa_prefixlen = 24; req.ifa.ifa_flags = 0; req.ifa.ifa_scope = 0; - req.ifa.ifa_index = loopback_link->index; + req.ifa.ifa_index = loopback_link.index; req.rtattr.rta_type = IFA_LOCAL; req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr)); inet_pton(AF_INET, "10.0.0.1", &req.addr); diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc index 53eb3b6b2..bde1dbb4d 100644 --- a/test/syscalls/linux/socket_netlink_route_util.cc +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -18,7 +18,6 @@ #include <linux/netlink.h> #include <linux/rtnetlink.h> -#include "absl/types/optional.h" #include "test/syscalls/linux/socket_netlink_util.h" namespace gvisor { @@ -73,14 +72,14 @@ PosixErrorOr<std::vector<Link>> DumpLinks() { return links; } -PosixErrorOr<absl::optional<Link>> FindLoopbackLink() { +PosixErrorOr<Link> LoopbackLink() { ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); for (const auto& link : links) { if (link.type == ARPHRD_LOOPBACK) { - return absl::optional<Link>(link); + return link; } } - return absl::optional<Link>(); + return PosixError(ENOENT, "loopback link not found"); } PosixError LinkAddLocalAddr(int index, int family, int prefixlen, diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h index 2c018e487..149c4a7f6 100644 --- a/test/syscalls/linux/socket_netlink_route_util.h +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -20,7 +20,6 @@ #include <vector> -#include "absl/types/optional.h" #include "test/syscalls/linux/socket_netlink_util.h" namespace gvisor { @@ -37,7 +36,8 @@ PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq, PosixErrorOr<std::vector<Link>> DumpLinks(); -PosixErrorOr<absl::optional<Link>> FindLoopbackLink(); +// Returns the loopback link on the system. ENOENT if not found. +PosixErrorOr<Link> LoopbackLink(); // LinkAddLocalAddr sets IFA_LOCAL attribute on the interface. PosixError LinkAddLocalAddr(int index, int family, int prefixlen, diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index 5d3a39868..53b678e94 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -364,11 +364,6 @@ CreateTCPConnectAcceptSocketPair(int bound, int connected, int type, } MaybeSave(); // Successful accept. - // FIXME(b/110484944) - if (connect_result == -1) { - absl::SleepFor(absl::Seconds(1)); - } - T extra_addr = {}; LocalhostAddr(&extra_addr, dual_stack); return absl::make_unique<AddrFDSocketPair>(connected, accepted, bind_addr, diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc index 4cf1f76f1..8bf663e8b 100644 --- a/test/syscalls/linux/socket_unix.cc +++ b/test/syscalls/linux/socket_unix.cc @@ -257,6 +257,8 @@ TEST_P(UnixSocketPairTest, ShutdownWrite) { TEST_P(UnixSocketPairTest, SocketReopenFromProcfs) { // TODO(b/122310852): We should be returning ENXIO and NOT EIO. + // TODO(github.dev/issue/1624): This should be resolved in VFS2. Verify + // that this is the case and delete the SKIP_IF once we delete VFS1. SKIP_IF(IsRunningOnGvisor()); auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index faa1247f6..f103e2e56 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <fcntl.h> +#include <linux/unistd.h> #include <sys/eventfd.h> #include <sys/resource.h> #include <sys/sendfile.h> diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc index 53ad2dda3..6195b11e1 100644 --- a/test/syscalls/linux/tuntap.cc +++ b/test/syscalls/linux/tuntap.cc @@ -56,14 +56,14 @@ PosixErrorOr<std::set<std::string>> DumpLinkNames() { return names; } -PosixErrorOr<absl::optional<Link>> GetLinkByName(const std::string& name) { +PosixErrorOr<Link> GetLinkByName(const std::string& name) { ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); for (const auto& link : links) { if (link.name == name) { - return absl::optional<Link>(link); + return link; } } - return absl::optional<Link>(); + return PosixError(ENOENT, "interface not found"); } struct pihdr { @@ -242,7 +242,7 @@ TEST_F(TuntapTest, InvalidReadWrite) { TEST_F(TuntapTest, WriteToDownDevice) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - // FIXME: gVisor always creates enabled/up'd interfaces. + // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces. SKIP_IF(IsRunningOnGvisor()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); @@ -268,23 +268,21 @@ PosixErrorOr<FileDescriptor> OpenAndAttachTap( return PosixError(errno); } - ASSIGN_OR_RETURN_ERRNO(absl::optional<Link> link, GetLinkByName(dev_name)); - if (!link.has_value()) { - return PosixError(ENOENT, "no link"); - } + ASSIGN_OR_RETURN_ERRNO(auto link, GetLinkByName(dev_name)); // Interface setup. struct in_addr addr; inet_pton(AF_INET, dev_ipv4_addr.c_str(), &addr); - EXPECT_NO_ERRNO(LinkAddLocalAddr(link->index, AF_INET, /*prefixlen=*/24, - &addr, sizeof(addr))); + EXPECT_NO_ERRNO(LinkAddLocalAddr(link.index, AF_INET, /*prefixlen=*/24, &addr, + sizeof(addr))); if (!IsRunningOnGvisor()) { - // FIXME: gVisor doesn't support setting MAC address on interfaces yet. - RETURN_IF_ERRNO(LinkSetMacAddr(link->index, kMacA, sizeof(kMacA))); + // FIXME(b/110961832): gVisor doesn't support setting MAC address on + // interfaces yet. + RETURN_IF_ERRNO(LinkSetMacAddr(link.index, kMacA, sizeof(kMacA))); - // FIXME: gVisor always creates enabled/up'd interfaces. - RETURN_IF_ERRNO(LinkChangeFlags(link->index, IFF_UP, IFF_UP)); + // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces. + RETURN_IF_ERRNO(LinkChangeFlags(link.index, IFF_UP, IFF_UP)); } return fd; diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc index 6218fbce1..ff66a79f4 100644 --- a/test/syscalls/linux/uidgid.cc +++ b/test/syscalls/linux/uidgid.cc @@ -14,6 +14,7 @@ #include <errno.h> #include <grp.h> +#include <sys/resource.h> #include <sys/types.h> #include <unistd.h> @@ -249,6 +250,17 @@ TEST(UidGidRootTest, Setgroups) { SyscallFailsWithErrno(EFAULT)); } +TEST(UidGidRootTest, Setuid_prlimit) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); + + // Change our UID. + EXPECT_THAT(seteuid(65534), SyscallSucceeds()); + + // Despite the UID change, we should be able to get our own limits. + struct rlimit rl = {}; + ASSERT_THAT(prlimit(0, RLIMIT_NOFILE, NULL, &rl), SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/utimes.cc b/test/syscalls/linux/utimes.cc index 3a927a430..22e6d1a85 100644 --- a/test/syscalls/linux/utimes.cc +++ b/test/syscalls/linux/utimes.cc @@ -34,17 +34,10 @@ namespace testing { namespace { -// TODO(b/36516566): utimes(nullptr) does not pick the "now" time in the -// application's time domain, so when asserting that times are within a window, -// we expand the window to allow for differences between the time domains. -constexpr absl::Duration kClockSlack = absl::Milliseconds(100); - // TimeBoxed runs fn, setting before and after to (coarse realtime) times // guaranteed* to come before and after fn started and completed, respectively. // // fn may be called more than once if the clock is adjusted. -// -// * See the comment on kClockSlack. gVisor breaks this guarantee. void TimeBoxed(absl::Time* before, absl::Time* after, std::function<void()> const& fn) { do { @@ -69,12 +62,6 @@ void TimeBoxed(absl::Time* before, absl::Time* after, // which could lead to test failures, but that is very unlikely to happen. continue; } - - if (IsRunningOnGvisor()) { - // See comment on kClockSlack. - *before -= kClockSlack; - *after += kClockSlack; - } } while (*after < *before); } @@ -235,10 +222,7 @@ void TestUtimensat(int dirFd, std::string const& path) { EXPECT_GE(mtime3, before); EXPECT_LE(mtime3, after); - if (!IsRunningOnGvisor()) { - // FIXME(b/36516566): Gofers set atime and mtime to different "now" times. - EXPECT_EQ(atime3, mtime3); - } + EXPECT_EQ(atime3, mtime3); } TEST(UtimensatTest, OnAbsPath) { diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc index 9b219cfd6..39b5b2f56 100644 --- a/test/syscalls/linux/write.cc +++ b/test/syscalls/linux/write.cc @@ -31,14 +31,8 @@ namespace gvisor { namespace testing { namespace { -// This test is currently very rudimentary. -// -// TODO(edahlgren): -// * bad buffer states (EFAULT). -// * bad fds (wrong permission, wrong type of file, EBADF). -// * check offset is incremented. -// * check for EOF. -// * writing to pipes, symlinks, special files. + +// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary. class WriteTest : public ::testing::Test { public: ssize_t WriteBytes(int fd, int bytes) { diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc index 8b00ef44c..3231732ec 100644 --- a/test/syscalls/linux/xattr.cc +++ b/test/syscalls/linux/xattr.cc @@ -41,12 +41,12 @@ class XattrTest : public FileTest {}; TEST_F(XattrTest, XattrNonexistentFile) { const char* path = "/does/not/exist"; - EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, /*flags=*/0), - SyscallFailsWithErrno(ENOENT)); - EXPECT_THAT(getxattr(path, nullptr, nullptr, 0), + const char* name = "user.test"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallFailsWithErrno(ENOENT)); + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENOENT)); EXPECT_THAT(listxattr(path, nullptr, 0), SyscallFailsWithErrno(ENOENT)); - EXPECT_THAT(removexattr(path, nullptr), SyscallFailsWithErrno(ENOENT)); + EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(ENOENT)); } TEST_F(XattrTest, XattrNullName) { diff --git a/test/util/capability_util.cc b/test/util/capability_util.cc index 9fee52fbb..a1b994c45 100644 --- a/test/util/capability_util.cc +++ b/test/util/capability_util.cc @@ -63,13 +63,13 @@ PosixErrorOr<bool> CanCreateUserNamespace() { // is in a chroot environment (i.e., the caller's root directory does // not match the root directory of the mount namespace in which it // resides)." - std::cerr << "clone(CLONE_NEWUSER) failed with EPERM"; + std::cerr << "clone(CLONE_NEWUSER) failed with EPERM" << std::endl; return false; } else if (errno == EUSERS) { // "(since Linux 3.11) CLONE_NEWUSER was specified in flags, and the call // would cause the limit on the number of nested user namespaces to be // exceeded. See user_namespaces(7)." - std::cerr << "clone(CLONE_NEWUSER) failed with EUSERS"; + std::cerr << "clone(CLONE_NEWUSER) failed with EUSERS" << std::endl; return false; } else { // Unexpected error code; indicate an actual error. |