summaryrefslogtreecommitdiffhomepage
path: root/test/packetimpact
diff options
context:
space:
mode:
Diffstat (limited to 'test/packetimpact')
-rw-r--r--test/packetimpact/dut/posix_server.cc56
-rw-r--r--test/packetimpact/proto/posix_server.proto87
-rw-r--r--test/packetimpact/testbench/BUILD10
-rw-r--r--test/packetimpact/testbench/connections.go99
-rw-r--r--test/packetimpact/testbench/dut.go351
-rw-r--r--test/packetimpact/testbench/layers.go158
-rw-r--r--test/packetimpact/testbench/layers_test.go144
-rw-r--r--test/packetimpact/testbench/rawsockets.go6
-rw-r--r--test/packetimpact/tests/BUILD24
-rw-r--r--test/packetimpact/tests/fin_wait2_timeout_test.go12
-rw-r--r--test/packetimpact/tests/tcp_noaccept_close_rst_test.go37
-rw-r--r--test/packetimpact/tests/tcp_window_shrink_test.go58
12 files changed, 762 insertions, 280 deletions
diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc
index 4a71c54c6..86e580c6f 100644
--- a/test/packetimpact/dut/posix_server.cc
+++ b/test/packetimpact/dut/posix_server.cc
@@ -61,13 +61,15 @@
}
class PosixImpl final : public posix_server::Posix::Service {
- ::grpc::Status Socket(grpc_impl::ServerContext *context,
- const ::posix_server::SocketRequest *request,
- ::posix_server::SocketResponse *response) override {
- response->set_fd(
- socket(request->domain(), request->type(), request->protocol()));
+ ::grpc::Status Accept(grpc_impl::ServerContext *context,
+ const ::posix_server::AcceptRequest *request,
+ ::posix_server::AcceptResponse *response) override {
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ response->set_fd(accept(request->sockfd(),
+ reinterpret_cast<sockaddr *>(&addr), &addrlen));
response->set_errno_(errno);
- return ::grpc::Status::OK;
+ return sockaddr_to_proto(addr, addrlen, response->mutable_addr());
}
::grpc::Status Bind(grpc_impl::ServerContext *context,
@@ -119,6 +121,14 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status::OK;
}
+ ::grpc::Status Close(grpc_impl::ServerContext *context,
+ const ::posix_server::CloseRequest *request,
+ ::posix_server::CloseResponse *response) override {
+ response->set_ret(close(request->fd()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
::grpc::Status GetSockName(
grpc_impl::ServerContext *context,
const ::posix_server::GetSockNameRequest *request,
@@ -139,15 +149,13 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status::OK;
}
- ::grpc::Status Accept(grpc_impl::ServerContext *context,
- const ::posix_server::AcceptRequest *request,
- ::posix_server::AcceptResponse *response) override {
- sockaddr_storage addr;
- socklen_t addrlen = sizeof(addr);
- response->set_fd(accept(request->sockfd(),
- reinterpret_cast<sockaddr *>(&addr), &addrlen));
+ ::grpc::Status Send(::grpc::ServerContext *context,
+ const ::posix_server::SendRequest *request,
+ ::posix_server::SendResponse *response) override {
+ response->set_ret(::send(request->sockfd(), request->buf().data(),
+ request->buf().size(), request->flags()));
response->set_errno_(errno);
- return sockaddr_to_proto(addr, addrlen, response->mutable_addr());
+ return ::grpc::Status::OK;
}
::grpc::Status SetSockOpt(
@@ -161,6 +169,17 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status::OK;
}
+ ::grpc::Status SetSockOptInt(
+ ::grpc::ServerContext *context,
+ const ::posix_server::SetSockOptIntRequest *request,
+ ::posix_server::SetSockOptIntResponse *response) override {
+ int opt = request->intval();
+ response->set_ret(::setsockopt(request->sockfd(), request->level(),
+ request->optname(), &opt, sizeof(opt)));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
::grpc::Status SetSockOptTimeval(
::grpc::ServerContext *context,
const ::posix_server::SetSockOptTimevalRequest *request,
@@ -174,10 +193,11 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status::OK;
}
- ::grpc::Status Close(grpc_impl::ServerContext *context,
- const ::posix_server::CloseRequest *request,
- ::posix_server::CloseResponse *response) override {
- response->set_ret(close(request->fd()));
+ ::grpc::Status Socket(grpc_impl::ServerContext *context,
+ const ::posix_server::SocketRequest *request,
+ ::posix_server::SocketResponse *response) override {
+ response->set_fd(
+ socket(request->domain(), request->type(), request->protocol()));
response->set_errno_(errno);
return ::grpc::Status::OK;
}
diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto
index 53ec49410..4035e1ee6 100644
--- a/test/packetimpact/proto/posix_server.proto
+++ b/test/packetimpact/proto/posix_server.proto
@@ -16,17 +16,6 @@ syntax = "proto3";
package posix_server;
-message SocketRequest {
- int32 domain = 1;
- int32 type = 2;
- int32 protocol = 3;
-}
-
-message SocketResponse {
- int32 fd = 1;
- int32 errno_ = 2; // "errno" may fail to compile in c++.
-}
-
message SockaddrIn {
int32 family = 1;
uint32 port = 2;
@@ -48,6 +37,23 @@ message Sockaddr {
}
}
+message Timeval {
+ int64 seconds = 1;
+ int64 microseconds = 2;
+}
+
+// Request and Response pairs for each Posix service RPC call, sorted.
+
+message AcceptRequest {
+ int32 sockfd = 1;
+}
+
+message AcceptResponse {
+ int32 fd = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ Sockaddr addr = 3;
+}
+
message BindRequest {
int32 sockfd = 1;
Sockaddr addr = 2;
@@ -58,6 +64,15 @@ message BindResponse {
int32 errno_ = 2; // "errno" may fail to compile in c++.
}
+message CloseRequest {
+ int32 fd = 1;
+}
+
+message CloseResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
message GetSockNameRequest {
int32 sockfd = 1;
}
@@ -78,14 +93,15 @@ message ListenResponse {
int32 errno_ = 2; // "errno" may fail to compile in c++.
}
-message AcceptRequest {
+message SendRequest {
int32 sockfd = 1;
+ bytes buf = 2;
+ int32 flags = 3;
}
-message AcceptResponse {
- int32 fd = 1;
- int32 errno_ = 2; // "errno" may fail to compile in c++.
- Sockaddr addr = 3;
+message SendResponse {
+ int32 ret = 1;
+ int32 errno_ = 2;
}
message SetSockOptRequest {
@@ -100,9 +116,16 @@ message SetSockOptResponse {
int32 errno_ = 2; // "errno" may fail to compile in c++.
}
-message Timeval {
- int64 seconds = 1;
- int64 microseconds = 2;
+message SetSockOptIntRequest {
+ int32 sockfd = 1;
+ int32 level = 2;
+ int32 optname = 3;
+ int32 intval = 4;
+}
+
+message SetSockOptIntResponse {
+ int32 ret = 1;
+ int32 errno_ = 2;
}
message SetSockOptTimevalRequest {
@@ -117,12 +140,14 @@ message SetSockOptTimevalResponse {
int32 errno_ = 2; // "errno" may fail to compile in c++.
}
-message CloseRequest {
- int32 fd = 1;
+message SocketRequest {
+ int32 domain = 1;
+ int32 type = 2;
+ int32 protocol = 3;
}
-message CloseResponse {
- int32 ret = 1;
+message SocketResponse {
+ int32 fd = 1;
int32 errno_ = 2; // "errno" may fail to compile in c++.
}
@@ -139,26 +164,30 @@ message RecvResponse {
}
service Posix {
- // Call socket() on the DUT.
- rpc Socket(SocketRequest) returns (SocketResponse);
+ // Call accept() on the DUT.
+ rpc Accept(AcceptRequest) returns (AcceptResponse);
// Call bind() on the DUT.
rpc Bind(BindRequest) returns (BindResponse);
+ // Call close() on the DUT.
+ rpc Close(CloseRequest) returns (CloseResponse);
// Call getsockname() on the DUT.
rpc GetSockName(GetSockNameRequest) returns (GetSockNameResponse);
// Call listen() on the DUT.
rpc Listen(ListenRequest) returns (ListenResponse);
- // Call accept() on the DUT.
- rpc Accept(AcceptRequest) returns (AcceptResponse);
+ // Call send() on the DUT.
+ rpc Send(SendRequest) returns (SendResponse);
// Call setsockopt() on the DUT. You should prefer one of the other
// SetSockOpt* functions with a more structured optval or else you may get the
// encoding wrong, such as making a bad assumption about the server's word
// sizes or endianness.
rpc SetSockOpt(SetSockOptRequest) returns (SetSockOptResponse);
+ // Call setsockopt() on the DUT with an int optval.
+ rpc SetSockOptInt(SetSockOptIntRequest) returns (SetSockOptIntResponse);
// Call setsockopt() on the DUT with a Timeval optval.
rpc SetSockOptTimeval(SetSockOptTimevalRequest)
returns (SetSockOptTimevalResponse);
- // Call close() on the DUT.
- rpc Close(CloseRequest) returns (CloseResponse);
+ // Call socket() on the DUT.
+ rpc Socket(SocketRequest) returns (SocketResponse);
// Call recv() on the DUT.
rpc Recv(RecvRequest) returns (RecvResponse);
}
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
index 4a9d8efa6..838a10ffe 100644
--- a/test/packetimpact/testbench/BUILD
+++ b/test/packetimpact/testbench/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(
default_visibility = ["//test/packetimpact:__subpackages__"],
@@ -30,3 +30,11 @@ go_library(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "testbench_test",
+ size = "small",
+ srcs = ["layers_test.go"],
+ library = ":testbench",
+ deps = ["//pkg/tcpip"],
+)
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index 8d1f562ee..79c0ccf5c 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -21,6 +21,7 @@ import (
"fmt"
"math/rand"
"net"
+ "strings"
"testing"
"time"
@@ -187,9 +188,19 @@ func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
conn.SendFrame(conn.CreateFrame(tcp, additionalLayers...))
}
-// Recv gets a packet from the sniffer within the timeout provided. If no packet
-// arrives before the timeout, it returns nil.
+// Recv gets a packet from the sniffer within the timeout provided.
+// If no packet arrives before the timeout, it returns nil.
func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP {
+ layers := conn.RecvFrame(timeout)
+ if tcpLayerIndex < len(layers) {
+ return layers[tcpLayerIndex].(*TCP)
+ }
+ return nil
+}
+
+// RecvFrame gets a frame (of type Layers) within the timeout provided.
+// If no frame arrives before the timeout, it returns nil.
+func (conn *TCPIPv4) RecvFrame(timeout time.Duration) Layers {
deadline := time.Now().Add(timeout)
for {
timeout = time.Until(deadline)
@@ -200,11 +211,7 @@ func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP {
if b == nil {
break
}
- layers, err := ParseEther(b)
- if err != nil {
- conn.t.Logf("can't parse frame: %s", err)
- continue // Ignore packets that can't be parsed.
- }
+ layers := Parse(ParseEther, b)
if !conn.incoming.match(layers) {
continue // Ignore packets that don't match the expected incoming.
}
@@ -216,27 +223,59 @@ func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP {
for i := tcpLayerIndex + 1; i < len(layers); i++ {
conn.RemoteSeqNum.UpdateForward(seqnum.Size(layers[i].length()))
}
- return tcpHeader
+ return layers
}
return nil
}
// Expect a packet that matches the provided tcp within the timeout specified.
-// If it doesn't arrive in time, the test fails.
-func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) *TCP {
+// If it doesn't arrive in time, it returns nil.
+func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) {
+ // We cannot implement this directly using ExpectFrame as we cannot specify
+ // the Payload part.
deadline := time.Now().Add(timeout)
+ var allTCP []string
for {
- timeout = time.Until(deadline)
- if timeout <= 0 {
- return nil
+ var gotTCP *TCP
+ if timeout = time.Until(deadline); timeout > 0 {
+ gotTCP = conn.Recv(timeout)
}
- gotTCP := conn.Recv(timeout)
if gotTCP == nil {
- return nil
+ return nil, fmt.Errorf("got %d packets:\n%s", len(allTCP), strings.Join(allTCP, "\n"))
}
if tcp.match(gotTCP) {
- return gotTCP
+ return gotTCP, nil
}
+ allTCP = append(allTCP, gotTCP.String())
+ }
+}
+
+// ExpectFrame expects a frame that matches the specified layers within the
+// timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *TCPIPv4) ExpectFrame(layers Layers, timeout time.Duration) Layers {
+ deadline := time.Now().Add(timeout)
+ for {
+ timeout = time.Until(deadline)
+ if timeout <= 0 {
+ return nil
+ }
+ gotLayers := conn.RecvFrame(timeout)
+ if layers.match(gotLayers) {
+ return gotLayers
+ }
+ }
+}
+
+// ExpectData is a convenient method that expects a TCP packet along with
+// the payload to arrive within the timeout specified. If it doesn't arrive
+// in time, it causes a fatal test failure.
+func (conn *TCPIPv4) ExpectData(tcp TCP, data []byte, timeout time.Duration) {
+ expected := []Layer{&Ether{}, &IPv4{}, &tcp}
+ if len(data) > 0 {
+ expected = append(expected, &Payload{Bytes: data})
+ }
+ if conn.ExpectFrame(expected, timeout) == nil {
+ conn.t.Fatalf("expected to get a TCP frame %s with payload %x", &tcp, data)
}
}
@@ -246,10 +285,11 @@ func (conn *TCPIPv4) Handshake() {
conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)})
// Wait for the SYN-ACK.
- conn.SynAck = conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
- if conn.SynAck == nil {
- conn.t.Fatalf("didn't get synack during handshake")
+ synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if synAck == nil {
+ conn.t.Fatalf("didn't get synack during handshake: %s", err)
}
+ conn.SynAck = synAck
// Send an ACK.
conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)})
@@ -374,11 +414,7 @@ func (conn *UDPIPv4) Recv(timeout time.Duration) *UDP {
if b == nil {
break
}
- layers, err := ParseEther(b)
- if err != nil {
- conn.t.Logf("can't parse frame: %s", err)
- continue // Ignore packets that can't be parsed.
- }
+ layers := Parse(ParseEther, b)
if !conn.incoming.match(layers) {
continue // Ignore packets that don't match the expected incoming.
}
@@ -389,19 +425,20 @@ func (conn *UDPIPv4) Recv(timeout time.Duration) *UDP {
// Expect a packet that matches the provided udp within the timeout specified.
// If it doesn't arrive in time, the test fails.
-func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) *UDP {
+func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) (*UDP, error) {
deadline := time.Now().Add(timeout)
+ var allUDP []string
for {
- timeout = time.Until(deadline)
- if timeout <= 0 {
- return nil
+ var gotUDP *UDP
+ if timeout = time.Until(deadline); timeout > 0 {
+ gotUDP = conn.Recv(timeout)
}
- gotUDP := conn.Recv(timeout)
if gotUDP == nil {
- return nil
+ return nil, fmt.Errorf("got %d packets:\n%s", len(allUDP), strings.Join(allUDP, "\n"))
}
if udp.match(gotUDP) {
- return gotUDP
+ return gotUDP, nil
}
+ allUDP = append(allUDP, gotUDP.String())
}
}
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
index d102dc7bb..9335909c0 100644
--- a/test/packetimpact/testbench/dut.go
+++ b/test/packetimpact/testbench/dut.go
@@ -65,33 +65,6 @@ func (dut *DUT) TearDown() {
dut.conn.Close()
}
-// SocketWithErrno calls socket on the DUT and returns the fd and errno.
-func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
- dut.t.Helper()
- req := pb.SocketRequest{
- Domain: domain,
- Type: typ,
- Protocol: proto,
- }
- ctx := context.Background()
- resp, err := dut.posixServer.Socket(ctx, &req)
- if err != nil {
- dut.t.Fatalf("failed to call Socket: %s", err)
- }
- return resp.GetFd(), syscall.Errno(resp.GetErrno_())
-}
-
-// Socket calls socket on the DUT and returns the file descriptor. If socket
-// fails on the DUT, the test ends.
-func (dut *DUT) Socket(domain, typ, proto int32) int32 {
- dut.t.Helper()
- fd, err := dut.SocketWithErrno(domain, typ, proto)
- if fd < 0 {
- dut.t.Fatalf("failed to create socket: %s", err)
- }
- return fd
-}
-
func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr {
dut.t.Helper()
switch s := sa.(type) {
@@ -142,6 +115,88 @@ func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr {
return nil
}
+// CreateBoundSocket makes a new socket on the DUT, with type typ and protocol
+// proto, and bound to the IP address addr. Returns the new file descriptor and
+// the port that was selected on the DUT.
+func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) {
+ dut.t.Helper()
+ var fd int32
+ if addr.To4() != nil {
+ fd = dut.Socket(unix.AF_INET, typ, proto)
+ sa := unix.SockaddrInet4{}
+ copy(sa.Addr[:], addr.To4())
+ dut.Bind(fd, &sa)
+ } else if addr.To16() != nil {
+ fd = dut.Socket(unix.AF_INET6, typ, proto)
+ sa := unix.SockaddrInet6{}
+ copy(sa.Addr[:], addr.To16())
+ dut.Bind(fd, &sa)
+ } else {
+ dut.t.Fatal("unknown ip addr type for remoteIP")
+ }
+ sa := dut.GetSockName(fd)
+ var port int
+ switch s := sa.(type) {
+ case *unix.SockaddrInet4:
+ port = s.Port
+ case *unix.SockaddrInet6:
+ port = s.Port
+ default:
+ dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa)
+ }
+ return fd, uint16(port)
+}
+
+// CreateListener makes a new TCP connection. If it fails, the test ends.
+func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) {
+ fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(*remoteIPv4))
+ dut.Listen(fd, backlog)
+ return fd, remotePort
+}
+
+// All the functions that make gRPC calls to the Posix service are below, sorted
+// alphabetically.
+
+// Accept calls accept on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// AcceptWithErrno.
+func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ fd, sa, err := dut.AcceptWithErrno(ctx, sockfd)
+ if fd < 0 {
+ dut.t.Fatalf("failed to accept: %s", err)
+ }
+ return fd, sa
+}
+
+// AcceptWithErrno calls accept on the DUT.
+func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
+ dut.t.Helper()
+ req := pb.AcceptRequest{
+ Sockfd: sockfd,
+ }
+ resp, err := dut.posixServer.Accept(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Accept: %s", err)
+ }
+ return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+}
+
+// Bind calls bind on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is
+// needed, use BindWithErrno.
+func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.BindWithErrno(ctx, fd, sa)
+ if ret != 0 {
+ dut.t.Fatalf("failed to bind socket: %s", err)
+ }
+}
+
// BindWithErrno calls bind on the DUT.
func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) {
dut.t.Helper()
@@ -156,30 +211,30 @@ func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-// Bind calls bind on the DUT and causes a fatal test failure if it doesn't
-// succeed. If more control over the timeout or error handling is
-// needed, use BindWithErrno.
-func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) {
+// Close calls close on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// CloseWithErrno.
+func (dut *DUT) Close(fd int32) {
dut.t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
defer cancel()
- ret, err := dut.BindWithErrno(ctx, fd, sa)
+ ret, err := dut.CloseWithErrno(ctx, fd)
if ret != 0 {
- dut.t.Fatalf("failed to bind socket: %s", err)
+ dut.t.Fatalf("failed to close: %s", err)
}
}
-// GetSockNameWithErrno calls getsockname on the DUT.
-func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
+// CloseWithErrno calls close on the DUT.
+func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) {
dut.t.Helper()
- req := pb.GetSockNameRequest{
- Sockfd: sockfd,
+ req := pb.CloseRequest{
+ Fd: fd,
}
- resp, err := dut.posixServer.GetSockName(ctx, &req)
+ resp, err := dut.posixServer.Close(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Bind: %s", err)
+ dut.t.Fatalf("failed to call Close: %s", err)
}
- return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
// GetSockName calls getsockname on the DUT and causes a fatal test failure if
@@ -196,6 +251,32 @@ func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr {
return sa
}
+// GetSockNameWithErrno calls getsockname on the DUT.
+func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
+ dut.t.Helper()
+ req := pb.GetSockNameRequest{
+ Sockfd: sockfd,
+ }
+ resp, err := dut.posixServer.GetSockName(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Bind: %s", err)
+ }
+ return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+}
+
+// Listen calls listen on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// ListenWithErrno.
+func (dut *DUT) Listen(sockfd, backlog int32) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.ListenWithErrno(ctx, sockfd, backlog)
+ if ret != 0 {
+ dut.t.Fatalf("failed to listen: %s", err)
+ }
+}
+
// ListenWithErrno calls listen on the DUT.
func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int32, error) {
dut.t.Helper()
@@ -210,44 +291,48 @@ func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-// Listen calls listen on the DUT and causes a fatal test failure if it doesn't
+// Send calls send on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
-// ListenWithErrno.
-func (dut *DUT) Listen(sockfd, backlog int32) {
+// SendWithErrno.
+func (dut *DUT) Send(sockfd int32, buf []byte, flags int32) int32 {
dut.t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
defer cancel()
- ret, err := dut.ListenWithErrno(ctx, sockfd, backlog)
- if ret != 0 {
- dut.t.Fatalf("failed to listen: %s", err)
+ ret, err := dut.SendWithErrno(ctx, sockfd, buf, flags)
+ if ret == -1 {
+ dut.t.Fatalf("failed to send: %s", err)
}
+ return ret
}
-// AcceptWithErrno calls accept on the DUT.
-func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
+// SendWithErrno calls send on the DUT.
+func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32) (int32, error) {
dut.t.Helper()
- req := pb.AcceptRequest{
+ req := pb.SendRequest{
Sockfd: sockfd,
+ Buf: buf,
+ Flags: flags,
}
- resp, err := dut.posixServer.Accept(ctx, &req)
+ resp, err := dut.posixServer.Send(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Accept: %s", err)
+ dut.t.Fatalf("failed to call Send: %s", err)
}
- return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-// Accept calls accept on the DUT and causes a fatal test failure if it doesn't
-// succeed. If more control over the timeout or error handling is needed, use
-// AcceptWithErrno.
-func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) {
+// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use SetSockOptWithErrno. Because endianess and the width of values
+// might differ between the testbench and DUT architectures, prefer to use a
+// more specific SetSockOptXxx function.
+func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) {
dut.t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
defer cancel()
- fd, sa, err := dut.AcceptWithErrno(ctx, sockfd)
- if fd < 0 {
- dut.t.Fatalf("failed to accept: %s", err)
+ ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval)
+ if ret != 0 {
+ dut.t.Fatalf("failed to SetSockOpt: %s", err)
}
- return fd, sa
}
// SetSockOptWithErrno calls setsockopt on the DUT. Because endianess and the
@@ -268,38 +353,31 @@ func (dut *DUT) SetSockOptWithErrno(ctx context.Context, sockfd, level, optname
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it
-// doesn't succeed. If more control over the timeout or error handling is
-// needed, use SetSockOptWithErrno. Because endianess and the width of values
-// might differ between the testbench and DUT architectures, prefer to use a
-// more specific SetSockOptXxx function.
-func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) {
+// SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the int optval or error handling
+// is needed, use SetSockOptIntWithErrno.
+func (dut *DUT) SetSockOptInt(sockfd, level, optname, optval int32) {
dut.t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
defer cancel()
- ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval)
+ ret, err := dut.SetSockOptIntWithErrno(ctx, sockfd, level, optname, optval)
if ret != 0 {
- dut.t.Fatalf("failed to SetSockOpt: %s", err)
+ dut.t.Fatalf("failed to SetSockOptInt: %s", err)
}
}
-// SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to
-// bytes.
-func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) {
+// SetSockOptIntWithErrno calls setsockopt with an integer optval.
+func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname, optval int32) (int32, error) {
dut.t.Helper()
- timeval := pb.Timeval{
- Seconds: int64(tv.Sec),
- Microseconds: int64(tv.Usec),
- }
- req := pb.SetSockOptTimevalRequest{
+ req := pb.SetSockOptIntRequest{
Sockfd: sockfd,
Level: level,
Optname: optname,
- Timeval: &timeval,
+ Intval: optval,
}
- resp, err := dut.posixServer.SetSockOptTimeval(ctx, &req)
+ resp, err := dut.posixServer.SetSockOptInt(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call SetSockOptTimeval: %s", err)
+ dut.t.Fatalf("failed to call SetSockOptInt: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
@@ -317,96 +395,79 @@ func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval
}
}
-// RecvWithErrno calls recv on the DUT.
-func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (int32, []byte, error) {
+// SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to
+// bytes.
+func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) {
dut.t.Helper()
- req := pb.RecvRequest{
- Sockfd: sockfd,
- Len: len,
- Flags: flags,
+ timeval := pb.Timeval{
+ Seconds: int64(tv.Sec),
+ Microseconds: int64(tv.Usec),
}
- resp, err := dut.posixServer.Recv(ctx, &req)
+ req := pb.SetSockOptTimevalRequest{
+ Sockfd: sockfd,
+ Level: level,
+ Optname: optname,
+ Timeval: &timeval,
+ }
+ resp, err := dut.posixServer.SetSockOptTimeval(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Recv: %s", err)
+ dut.t.Fatalf("failed to call SetSockOptTimeval: %s", err)
}
- return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-// Recv calls recv on the DUT and causes a fatal test failure if it doesn't
-// succeed. If more control over the timeout or error handling is needed, use
-// RecvWithErrno.
-func (dut *DUT) Recv(sockfd, len, flags int32) []byte {
+// Socket calls socket on the DUT and returns the file descriptor. If socket
+// fails on the DUT, the test ends.
+func (dut *DUT) Socket(domain, typ, proto int32) int32 {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
- defer cancel()
- ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags)
- if ret == -1 {
- dut.t.Fatalf("failed to recv: %s", err)
+ fd, err := dut.SocketWithErrno(domain, typ, proto)
+ if fd < 0 {
+ dut.t.Fatalf("failed to create socket: %s", err)
}
- return buf
+ return fd
}
-// CloseWithErrno calls close on the DUT.
-func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) {
+// SocketWithErrno calls socket on the DUT and returns the fd and errno.
+func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
dut.t.Helper()
- req := pb.CloseRequest{
- Fd: fd,
+ req := pb.SocketRequest{
+ Domain: domain,
+ Type: typ,
+ Protocol: proto,
}
- resp, err := dut.posixServer.Close(ctx, &req)
+ ctx := context.Background()
+ resp, err := dut.posixServer.Socket(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Close: %s", err)
+ dut.t.Fatalf("failed to call Socket: %s", err)
}
- return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+ return resp.GetFd(), syscall.Errno(resp.GetErrno_())
}
-// Close calls close on the DUT and causes a fatal test failure if it doesn't
+// Recv calls recv on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
-// CloseWithErrno.
-func (dut *DUT) Close(fd int32) {
+// RecvWithErrno.
+func (dut *DUT) Recv(sockfd, len, flags int32) []byte {
dut.t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
defer cancel()
- ret, err := dut.CloseWithErrno(ctx, fd)
- if ret != 0 {
- dut.t.Fatalf("failed to close: %s", err)
+ ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags)
+ if ret == -1 {
+ dut.t.Fatalf("failed to recv: %s", err)
}
+ return buf
}
-// CreateBoundSocket makes a new socket on the DUT, with type typ and protocol
-// proto, and bound to the IP address addr. Returns the new file descriptor and
-// the port that was selected on the DUT.
-func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) {
+// RecvWithErrno calls recv on the DUT.
+func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (int32, []byte, error) {
dut.t.Helper()
- var fd int32
- if addr.To4() != nil {
- fd = dut.Socket(unix.AF_INET, typ, proto)
- sa := unix.SockaddrInet4{}
- copy(sa.Addr[:], addr.To4())
- dut.Bind(fd, &sa)
- } else if addr.To16() != nil {
- fd = dut.Socket(unix.AF_INET6, typ, proto)
- sa := unix.SockaddrInet6{}
- copy(sa.Addr[:], addr.To16())
- dut.Bind(fd, &sa)
- } else {
- dut.t.Fatal("unknown ip addr type for remoteIP")
+ req := pb.RecvRequest{
+ Sockfd: sockfd,
+ Len: len,
+ Flags: flags,
}
- sa := dut.GetSockName(fd)
- var port int
- switch s := sa.(type) {
- case *unix.SockaddrInet4:
- port = s.Port
- case *unix.SockaddrInet6:
- port = s.Port
- default:
- dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa)
+ resp, err := dut.posixServer.Recv(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Recv: %s", err)
}
- return fd, uint16(port)
-}
-
-// CreateListener makes a new TCP connection. If it fails, the test ends.
-func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) {
- fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(*remoteIPv4))
- dut.Listen(fd, backlog)
- return fd, remotePort
+ return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_())
}
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index d7434c3d2..b467c15cc 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -17,6 +17,7 @@ package testbench
import (
"fmt"
"reflect"
+ "strings"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
@@ -32,6 +33,8 @@ import (
// Layer contains all the fields of the encapsulation. Each field is a pointer
// and may be nil.
type Layer interface {
+ fmt.Stringer
+
// toBytes converts the Layer into bytes. In places where the Layer's field
// isn't nil, the value that is pointed to is used. When the field is nil, a
// reasonable default for the Layer is used. For example, "64" for IPv4 TTL
@@ -43,7 +46,8 @@ type Layer interface {
// match checks if the current Layer matches the provided Layer. If either
// Layer has a nil in a given field, that field is considered matching.
- // Otherwise, the values pointed to by the fields must match.
+ // Otherwise, the values pointed to by the fields must match. The LayerBase is
+ // ignored.
match(Layer) bool
// length in bytes of the current encapsulation
@@ -84,18 +88,39 @@ func (lb *LayerBase) setPrev(l Layer) {
lb.prevLayer = l
}
+// equalLayer compares that two Layer structs match while ignoring field in
+// which either input has a nil and also ignoring the LayerBase of the inputs.
func equalLayer(x, y Layer) bool {
+ // opt ignores comparison pairs where either of the inputs is a nil.
opt := cmp.FilterValues(func(x, y interface{}) bool {
- if reflect.ValueOf(x).Kind() == reflect.Ptr && reflect.ValueOf(x).IsNil() {
- return true
- }
- if reflect.ValueOf(y).Kind() == reflect.Ptr && reflect.ValueOf(y).IsNil() {
- return true
+ for _, l := range []interface{}{x, y} {
+ v := reflect.ValueOf(l)
+ if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice) && v.IsNil() {
+ return true
+ }
}
return false
-
}, cmp.Ignore())
- return cmp.Equal(x, y, opt, cmpopts.IgnoreUnexported(LayerBase{}))
+ return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{}))
+}
+
+func stringLayer(l Layer) string {
+ v := reflect.ValueOf(l).Elem()
+ t := v.Type()
+ var ret []string
+ for i := 0; i < v.NumField(); i++ {
+ t := t.Field(i)
+ if t.Anonymous {
+ // Ignore the LayerBase in the Layer struct.
+ continue
+ }
+ v := v.Field(i)
+ if v.IsNil() {
+ continue
+ }
+ ret = append(ret, fmt.Sprintf("%s:%v", t.Name, reflect.Indirect(v)))
+ }
+ return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " "))
}
// Ether can construct and match an ethernet encapsulation.
@@ -106,6 +131,10 @@ type Ether struct {
Type *tcpip.NetworkProtocolNumber
}
+func (l *Ether) String() string {
+ return stringLayer(l)
+}
+
func (l *Ether) toBytes() ([]byte, error) {
b := make([]byte, header.EthernetMinimumSize)
h := header.Ethernet(b)
@@ -124,7 +153,7 @@ func (l *Ether) toBytes() ([]byte, error) {
fields.Type = header.IPv4ProtocolNumber
default:
// TODO(b/150301488): Support more protocols, like IPv6.
- return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", n)
+ return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n)
}
}
h.Encode(fields)
@@ -143,27 +172,46 @@ func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocol
return &v
}
+// LayerParser parses the input bytes and returns a Layer along with the next
+// LayerParser to run. If there is no more parsing to do, the returned
+// LayerParser is nil.
+type LayerParser func([]byte) (Layer, LayerParser)
+
+// Parse parses bytes starting with the first LayerParser and using successive
+// LayerParsers until all the bytes are parsed.
+func Parse(parser LayerParser, b []byte) Layers {
+ var layers Layers
+ for {
+ var layer Layer
+ layer, parser = parser(b)
+ layers = append(layers, layer)
+ if parser == nil {
+ break
+ }
+ b = b[layer.length():]
+ }
+ layers.linkLayers()
+ return layers
+}
+
// ParseEther parses the bytes assuming that they start with an ethernet header
// and continues parsing further encapsulations.
-func ParseEther(b []byte) (Layers, error) {
+func ParseEther(b []byte) (Layer, LayerParser) {
h := header.Ethernet(b)
ether := Ether{
SrcAddr: LinkAddress(h.SourceAddress()),
DstAddr: LinkAddress(h.DestinationAddress()),
Type: NetworkProtocolNumber(h.Type()),
}
- layers := Layers{&ether}
+ var nextParser LayerParser
switch h.Type() {
case header.IPv4ProtocolNumber:
- moreLayers, err := ParseIPv4(b[ether.length():])
- if err != nil {
- return nil, err
- }
- return append(layers, moreLayers...), nil
+ nextParser = ParseIPv4
default:
- // TODO(b/150301488): Support more protocols, like IPv6.
- return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %#v", b)
+ // Assume that the rest is a payload.
+ nextParser = ParsePayload
}
+ return &ether, nextParser
}
func (l *Ether) match(other Layer) bool {
@@ -190,6 +238,10 @@ type IPv4 struct {
DstAddr *tcpip.Address
}
+func (l *IPv4) String() string {
+ return stringLayer(l)
+}
+
func (l *IPv4) toBytes() ([]byte, error) {
b := make([]byte, header.IPv4MinimumSize)
h := header.IPv4(b)
@@ -241,7 +293,7 @@ func (l *IPv4) toBytes() ([]byte, error) {
fields.Protocol = uint8(header.UDPProtocolNumber)
default:
// TODO(b/150301488): Support more protocols as needed.
- return nil, fmt.Errorf("can't deduce the ip header's next protocol: %#v", n)
+ return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n)
}
}
if l.SrcAddr != nil {
@@ -280,7 +332,7 @@ func Address(v tcpip.Address) *tcpip.Address {
// ParseIPv4 parses the bytes assuming that they start with an ipv4 header and
// continues parsing further encapsulations.
-func ParseIPv4(b []byte) (Layers, error) {
+func ParseIPv4(b []byte) (Layer, LayerParser) {
h := header.IPv4(b)
tos, _ := h.TOS()
ipv4 := IPv4{
@@ -296,22 +348,17 @@ func ParseIPv4(b []byte) (Layers, error) {
SrcAddr: Address(h.SourceAddress()),
DstAddr: Address(h.DestinationAddress()),
}
- layers := Layers{&ipv4}
+ var nextParser LayerParser
switch h.TransportProtocol() {
case header.TCPProtocolNumber:
- moreLayers, err := ParseTCP(b[ipv4.length():])
- if err != nil {
- return nil, err
- }
- return append(layers, moreLayers...), nil
+ nextParser = ParseTCP
case header.UDPProtocolNumber:
- moreLayers, err := ParseUDP(b[ipv4.length():])
- if err != nil {
- return nil, err
- }
- return append(layers, moreLayers...), nil
+ nextParser = ParseUDP
+ default:
+ // Assume that the rest is a payload.
+ nextParser = ParsePayload
}
- return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", h.Protocol())
+ return &ipv4, nextParser
}
func (l *IPv4) match(other Layer) bool {
@@ -339,6 +386,10 @@ type TCP struct {
UrgentPointer *uint16
}
+func (l *TCP) String() string {
+ return stringLayer(l)
+}
+
func (l *TCP) toBytes() ([]byte, error) {
b := make([]byte, header.TCPMinimumSize)
h := header.TCP(b)
@@ -433,7 +484,7 @@ func Uint32(v uint32) *uint32 {
// ParseTCP parses the bytes assuming that they start with a tcp header and
// continues parsing further encapsulations.
-func ParseTCP(b []byte) (Layers, error) {
+func ParseTCP(b []byte) (Layer, LayerParser) {
h := header.TCP(b)
tcp := TCP{
SrcPort: Uint16(h.SourcePort()),
@@ -446,12 +497,7 @@ func ParseTCP(b []byte) (Layers, error) {
Checksum: Uint16(h.Checksum()),
UrgentPointer: Uint16(h.UrgentPointer()),
}
- layers := Layers{&tcp}
- moreLayers, err := ParsePayload(b[tcp.length():])
- if err != nil {
- return nil, err
- }
- return append(layers, moreLayers...), nil
+ return &tcp, ParsePayload
}
func (l *TCP) match(other Layer) bool {
@@ -480,6 +526,10 @@ type UDP struct {
Checksum *uint16
}
+func (l *UDP) String() string {
+ return stringLayer(l)
+}
+
func (l *UDP) toBytes() ([]byte, error) {
b := make([]byte, header.UDPMinimumSize)
h := header.UDP(b)
@@ -516,8 +566,8 @@ func setUDPChecksum(h *header.UDP, udp *UDP) error {
}
// ParseUDP parses the bytes assuming that they start with a udp header and
-// continues parsing further encapsulations.
-func ParseUDP(b []byte) (Layers, error) {
+// returns the parsed layer and the next parser to use.
+func ParseUDP(b []byte) (Layer, LayerParser) {
h := header.UDP(b)
udp := UDP{
SrcPort: Uint16(h.SourcePort()),
@@ -525,12 +575,7 @@ func ParseUDP(b []byte) (Layers, error) {
Length: Uint16(h.Length()),
Checksum: Uint16(h.Checksum()),
}
- layers := Layers{&udp}
- moreLayers, err := ParsePayload(b[udp.length():])
- if err != nil {
- return nil, err
- }
- return append(layers, moreLayers...), nil
+ return &udp, ParsePayload
}
func (l *UDP) match(other Layer) bool {
@@ -556,13 +601,17 @@ type Payload struct {
Bytes []byte
}
+func (l *Payload) String() string {
+ return stringLayer(l)
+}
+
// ParsePayload parses the bytes assuming that they start with a payload and
// continue to the end. There can be no further encapsulations.
-func ParsePayload(b []byte) (Layers, error) {
+func ParsePayload(b []byte) (Layer, LayerParser) {
payload := Payload{
Bytes: b,
}
- return Layers{&payload}, nil
+ return &payload, nil
}
func (l *Payload) toBytes() ([]byte, error) {
@@ -580,15 +629,24 @@ func (l *Payload) length() int {
// Layers is an array of Layer and supports similar functions to Layer.
type Layers []Layer
-func (ls *Layers) toBytes() ([]byte, error) {
+// linkLayers sets the linked-list ponters in ls.
+func (ls *Layers) linkLayers() {
for i, l := range *ls {
if i > 0 {
l.setPrev((*ls)[i-1])
+ } else {
+ l.setPrev(nil)
}
if i+1 < len(*ls) {
l.setNext((*ls)[i+1])
+ } else {
+ l.setNext(nil)
}
}
+}
+
+func (ls *Layers) toBytes() ([]byte, error) {
+ ls.linkLayers()
outBytes := []byte{}
for _, l := range *ls {
layerBytes, err := l.toBytes()
diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go
new file mode 100644
index 000000000..8ffc26bf9
--- /dev/null
+++ b/test/packetimpact/testbench/layers_test.go
@@ -0,0 +1,144 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import "testing"
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+func TestLayerMatch(t *testing.T) {
+ var nilPayload *Payload
+ noPayload := &Payload{}
+ emptyPayload := &Payload{Bytes: []byte{}}
+ fullPayload := &Payload{Bytes: []byte{1, 2, 3}}
+ emptyTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: emptyPayload}}
+ fullTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: fullPayload}}
+ for _, tt := range []struct {
+ a, b Layer
+ want bool
+ }{
+ {nilPayload, nilPayload, true},
+ {nilPayload, noPayload, true},
+ {nilPayload, emptyPayload, true},
+ {nilPayload, fullPayload, true},
+ {noPayload, noPayload, true},
+ {noPayload, emptyPayload, true},
+ {noPayload, fullPayload, true},
+ {emptyPayload, emptyPayload, true},
+ {emptyPayload, fullPayload, false},
+ {fullPayload, fullPayload, true},
+ {emptyTCP, fullTCP, true},
+ } {
+ if got := tt.a.match(tt.b); got != tt.want {
+ t.Errorf("%s.match(%s) = %t, want %t", tt.a, tt.b, got, tt.want)
+ }
+ if got := tt.b.match(tt.a); got != tt.want {
+ t.Errorf("%s.match(%s) = %t, want %t", tt.b, tt.a, got, tt.want)
+ }
+ }
+}
+
+func TestLayerStringFormat(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ l Layer
+ want string
+ }{
+ {
+ name: "TCP",
+ l: &TCP{
+ SrcPort: Uint16(34785),
+ DstPort: Uint16(47767),
+ SeqNum: Uint32(3452155723),
+ AckNum: Uint32(2596996163),
+ DataOffset: Uint8(5),
+ Flags: Uint8(20),
+ WindowSize: Uint16(64240),
+ Checksum: Uint16(0x2e2b),
+ },
+ want: "&testbench.TCP{" +
+ "SrcPort:34785 " +
+ "DstPort:47767 " +
+ "SeqNum:3452155723 " +
+ "AckNum:2596996163 " +
+ "DataOffset:5 " +
+ "Flags:20 " +
+ "WindowSize:64240 " +
+ "Checksum:11819" +
+ "}",
+ },
+ {
+ name: "UDP",
+ l: &UDP{
+ SrcPort: Uint16(34785),
+ DstPort: Uint16(47767),
+ Length: Uint16(12),
+ },
+ want: "&testbench.UDP{" +
+ "SrcPort:34785 " +
+ "DstPort:47767 " +
+ "Length:12" +
+ "}",
+ },
+ {
+ name: "IPv4",
+ l: &IPv4{
+ IHL: Uint8(5),
+ TOS: Uint8(0),
+ TotalLength: Uint16(44),
+ ID: Uint16(0),
+ Flags: Uint8(2),
+ FragmentOffset: Uint16(0),
+ TTL: Uint8(64),
+ Protocol: Uint8(6),
+ Checksum: Uint16(0x2e2b),
+ SrcAddr: Address(tcpip.Address([]byte{197, 34, 63, 10})),
+ DstAddr: Address(tcpip.Address([]byte{197, 34, 63, 20})),
+ },
+ want: "&testbench.IPv4{" +
+ "IHL:5 " +
+ "TOS:0 " +
+ "TotalLength:44 " +
+ "ID:0 " +
+ "Flags:2 " +
+ "FragmentOffset:0 " +
+ "TTL:64 " +
+ "Protocol:6 " +
+ "Checksum:11819 " +
+ "SrcAddr:197.34.63.10 " +
+ "DstAddr:197.34.63.20" +
+ "}",
+ },
+ {
+ name: "Ether",
+ l: &Ether{
+ SrcAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x0a})),
+ DstAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x14})),
+ Type: NetworkProtocolNumber(4),
+ },
+ want: "&testbench.Ether{" +
+ "SrcAddr:02:42:c5:22:3f:0a " +
+ "DstAddr:02:42:c5:22:3f:14 " +
+ "Type:4" +
+ "}",
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := tt.l.String(); got != tt.want {
+ t.Errorf("%s.String() = %s, want: %s", tt.name, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go
index 0c7d0f979..0074484f7 100644
--- a/test/packetimpact/testbench/rawsockets.go
+++ b/test/packetimpact/testbench/rawsockets.go
@@ -47,6 +47,12 @@ func NewSniffer(t *testing.T) (Sniffer, error) {
if err != nil {
return Sniffer{}, err
}
+ if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, 1); err != nil {
+ t.Fatalf("can't set sockopt SO_RCVBUFFORCE to 1: %s", err)
+ }
+ if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1e7); err != nil {
+ t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err)
+ }
return Sniffer{
t: t,
fd: snifferFd,
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index 9a4d66ea9..956b1addf 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -28,6 +28,30 @@ packetimpact_go_test(
],
)
+packetimpact_go_test(
+ name = "tcp_window_shrink",
+ srcs = ["tcp_window_shrink_test.go"],
+ # TODO(b/153202472): Fix netstack then remove the line below.
+ netstack = False,
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_noaccept_close_rst",
+ srcs = ["tcp_noaccept_close_rst_test.go"],
+ # TODO(b/153380909): Fix netstack then remove the line below.
+ netstack = False,
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
sh_binary(
name = "test_runner",
srcs = ["test_runner.sh"],
diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go
index 2b3f39045..90e16ef65 100644
--- a/test/packetimpact/tests/fin_wait2_timeout_test.go
+++ b/test/packetimpact/tests/fin_wait2_timeout_test.go
@@ -47,20 +47,20 @@ func TestFinWait2Timeout(t *testing.T) {
}
dut.Close(acceptFd)
- if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); gotOne == nil {
- t.Fatal("expected a FIN-ACK within 1 second but got none")
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err)
}
conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
time.Sleep(5 * time.Second)
conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
if tt.linger2 {
- if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); gotOne == nil {
- t.Fatal("expected a RST packet within a second but got none")
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ t.Fatalf("expected a RST packet within a second but got none: %s", err)
}
} else {
- if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, 10*time.Second); gotOne != nil {
- t.Fatal("expected no RST packets within ten seconds but got one")
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, 10*time.Second); err == nil {
+ t.Fatalf("expected no RST packets within ten seconds but got one: %s", err)
}
}
})
diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
new file mode 100644
index 000000000..7ebdd1950
--- /dev/null
+++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
@@ -0,0 +1,37 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_noaccept_close_rst_test
+
+import (
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func TestTcpNoAcceptCloseReset(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn.Handshake()
+ defer conn.Close()
+ dut.Close(listenFd)
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil {
+ t.Fatalf("expected a RST-ACK packet but got none: %s", err)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go
new file mode 100644
index 000000000..b48cc6491
--- /dev/null
+++ b/test/packetimpact/tests/tcp_window_shrink_test.go
@@ -0,0 +1,58 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_window_shrink_test
+
+import (
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func TestWindowShrink(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Handshake()
+ acceptFd, _ := dut.Accept(listenFd)
+ defer dut.Close(acceptFd)
+
+ dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ sampleData := []byte("Sample Data")
+
+ dut.Send(acceptFd, sampleData, 0)
+ conn.ExpectData(tb.TCP{}, sampleData, time.Second)
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+
+ dut.Send(acceptFd, sampleData, 0)
+ dut.Send(acceptFd, sampleData, 0)
+ conn.ExpectData(tb.TCP{}, sampleData, time.Second)
+ conn.ExpectData(tb.TCP{}, sampleData, time.Second)
+ // We close our receiving window here
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), WindowSize: tb.Uint16(0)})
+
+ dut.Send(acceptFd, []byte("Sample Data"), 0)
+ // Note: There is another kind of zero-window probing which Windows uses (by sending one
+ // new byte at `RemoteSeqNum`), if netstack wants to go that way, we may want to change
+ // the following lines.
+ conn.ExpectData(tb.TCP{SeqNum: tb.Uint32(uint32(conn.RemoteSeqNum - 1))}, nil, time.Second)
+}