summaryrefslogtreecommitdiffhomepage
path: root/test
diff options
context:
space:
mode:
authorgVisor bot <gvisor-bot@google.com>2020-04-28 18:49:19 -0700
committergVisor bot <gvisor-bot@google.com>2020-04-28 18:50:44 -0700
commit24abccbc1c3b7b0dd06b6da97e5b4c90c8c13907 (patch)
tree95911e3341b7296f7d80f5c0ee82e466e88184e2 /test
parentf93f2fda74f31246e8866783f6c4be2318bdedd6 (diff)
Internal change.
PiperOrigin-RevId: 308940886
Diffstat (limited to 'test')
-rw-r--r--test/packetimpact/dut/posix_server.cc163
-rw-r--r--test/packetimpact/proto/posix_server.proto73
-rw-r--r--test/packetimpact/testbench/connections.go120
-rw-r--r--test/packetimpact/testbench/dut.go153
-rw-r--r--test/packetimpact/testbench/layers.go104
-rw-r--r--test/packetimpact/tests/BUILD13
-rw-r--r--test/packetimpact/tests/udp_icmp_error_propagation_test.go209
7 files changed, 748 insertions, 87 deletions
diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc
index 86e580c6f..cb499b0b1 100644
--- a/test/packetimpact/dut/posix_server.cc
+++ b/test/packetimpact/dut/posix_server.cc
@@ -60,6 +60,45 @@
return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Unknown Sockaddr");
}
+::grpc::Status proto_to_sockaddr(const posix_server::Sockaddr &sockaddr_proto,
+ sockaddr_storage *addr) {
+ switch (sockaddr_proto.sockaddr_case()) {
+ case posix_server::Sockaddr::SockaddrCase::kIn: {
+ auto proto_in = sockaddr_proto.in();
+ if (proto_in.addr().size() != 4) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "IPv4 address must be 4 bytes");
+ }
+ auto addr_in = reinterpret_cast<sockaddr_in *>(addr);
+ addr_in->sin_family = proto_in.family();
+ addr_in->sin_port = htons(proto_in.port());
+ proto_in.addr().copy(reinterpret_cast<char *>(&addr_in->sin_addr.s_addr),
+ 4);
+ break;
+ }
+ case posix_server::Sockaddr::SockaddrCase::kIn6: {
+ auto proto_in6 = sockaddr_proto.in6();
+ if (proto_in6.addr().size() != 16) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "IPv6 address must be 16 bytes");
+ }
+ auto addr_in6 = reinterpret_cast<sockaddr_in6 *>(addr);
+ addr_in6->sin6_family = proto_in6.family();
+ addr_in6->sin6_port = htons(proto_in6.port());
+ addr_in6->sin6_flowinfo = htonl(proto_in6.flowinfo());
+ proto_in6.addr().copy(
+ reinterpret_cast<char *>(&addr_in6->sin6_addr.s6_addr), 16);
+ addr_in6->sin6_scope_id = htonl(proto_in6.scope_id());
+ break;
+ }
+ case posix_server::Sockaddr::SockaddrCase::SOCKADDR_NOT_SET:
+ default:
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Unknown Sockaddr");
+ }
+ return ::grpc::Status::OK;
+}
+
class PosixImpl final : public posix_server::Posix::Service {
::grpc::Status Accept(grpc_impl::ServerContext *context,
const ::posix_server::AcceptRequest *request,
@@ -79,42 +118,13 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
"Missing address");
}
- sockaddr_storage addr;
- switch (request->addr().sockaddr_case()) {
- case posix_server::Sockaddr::SockaddrCase::kIn: {
- auto request_in = request->addr().in();
- if (request_in.addr().size() != 4) {
- return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
- "IPv4 address must be 4 bytes");
- }
- auto addr_in = reinterpret_cast<sockaddr_in *>(&addr);
- addr_in->sin_family = request_in.family();
- addr_in->sin_port = htons(request_in.port());
- request_in.addr().copy(
- reinterpret_cast<char *>(&addr_in->sin_addr.s_addr), 4);
- break;
- }
- case posix_server::Sockaddr::SockaddrCase::kIn6: {
- auto request_in6 = request->addr().in6();
- if (request_in6.addr().size() != 16) {
- return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
- "IPv6 address must be 16 bytes");
- }
- auto addr_in6 = reinterpret_cast<sockaddr_in6 *>(&addr);
- addr_in6->sin6_family = request_in6.family();
- addr_in6->sin6_port = htons(request_in6.port());
- addr_in6->sin6_flowinfo = htonl(request_in6.flowinfo());
- request_in6.addr().copy(
- reinterpret_cast<char *>(&addr_in6->sin6_addr.s6_addr), 16);
- addr_in6->sin6_scope_id = htonl(request_in6.scope_id());
- break;
- }
- case posix_server::Sockaddr::SockaddrCase::SOCKADDR_NOT_SET:
- default:
- return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
- "Unknown Sockaddr");
+ sockaddr_storage addr;
+ auto err = proto_to_sockaddr(request->addr(), &addr);
+ if (!err.ok()) {
+ return err;
}
+
response->set_ret(bind(request->sockfd(),
reinterpret_cast<sockaddr *>(&addr), sizeof(addr)));
response->set_errno_(errno);
@@ -129,6 +139,25 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status::OK;
}
+ ::grpc::Status Connect(grpc_impl::ServerContext *context,
+ const ::posix_server::ConnectRequest *request,
+ ::posix_server::ConnectResponse *response) override {
+ if (!request->has_addr()) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Missing address");
+ }
+ sockaddr_storage addr;
+ auto err = proto_to_sockaddr(request->addr(), &addr);
+ if (!err.ok()) {
+ return err;
+ }
+
+ response->set_ret(connect(
+ request->sockfd(), reinterpret_cast<sockaddr *>(&addr), sizeof(addr)));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
::grpc::Status GetSockName(
grpc_impl::ServerContext *context,
const ::posix_server::GetSockNameRequest *request,
@@ -141,6 +170,48 @@ class PosixImpl final : public posix_server::Posix::Service {
return sockaddr_to_proto(addr, addrlen, response->mutable_addr());
}
+ ::grpc::Status GetSockOpt(
+ grpc_impl::ServerContext *context,
+ const ::posix_server::GetSockOptRequest *request,
+ ::posix_server::GetSockOptResponse *response) override {
+ socklen_t optlen = request->optlen();
+ std::vector<char> buf(optlen);
+ response->set_ret(::getsockopt(request->sockfd(), request->level(),
+ request->optname(), buf.data(), &optlen));
+ response->set_errno_(errno);
+ if (optlen >= 0) {
+ response->set_optval(buf.data(), optlen);
+ }
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status GetSockOptInt(
+ ::grpc::ServerContext *context,
+ const ::posix_server::GetSockOptIntRequest *request,
+ ::posix_server::GetSockOptIntResponse *response) override {
+ int opt = 0;
+ socklen_t optlen = sizeof(opt);
+ response->set_ret(::getsockopt(request->sockfd(), request->level(),
+ request->optname(), &opt, &optlen));
+ response->set_errno_(errno);
+ response->set_intval(opt);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status GetSockOptTimeval(
+ ::grpc::ServerContext *context,
+ const ::posix_server::GetSockOptTimevalRequest *request,
+ ::posix_server::GetSockOptTimevalResponse *response) override {
+ timeval tv;
+ socklen_t optlen = sizeof(tv);
+ response->set_ret(::getsockopt(request->sockfd(), request->level(),
+ request->optname(), &tv, &optlen));
+ response->set_errno_(errno);
+ response->mutable_timeval()->set_seconds(tv.tv_sec);
+ response->mutable_timeval()->set_microseconds(tv.tv_usec);
+ return ::grpc::Status::OK;
+ }
+
::grpc::Status Listen(grpc_impl::ServerContext *context,
const ::posix_server::ListenRequest *request,
::posix_server::ListenResponse *response) override {
@@ -158,6 +229,26 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status::OK;
}
+ ::grpc::Status SendTo(::grpc::ServerContext *context,
+ const ::posix_server::SendToRequest *request,
+ ::posix_server::SendToResponse *response) override {
+ if (!request->has_dest_addr()) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Missing address");
+ }
+ sockaddr_storage addr;
+ auto err = proto_to_sockaddr(request->dest_addr(), &addr);
+ if (!err.ok()) {
+ return err;
+ }
+
+ response->set_ret(::sendto(
+ request->sockfd(), request->buf().data(), request->buf().size(),
+ request->flags(), reinterpret_cast<sockaddr *>(&addr), sizeof(addr)));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
::grpc::Status SetSockOpt(
grpc_impl::ServerContext *context,
const ::posix_server::SetSockOptRequest *request,
@@ -208,8 +299,10 @@ class PosixImpl final : public posix_server::Posix::Service {
std::vector<char> buf(request->len());
response->set_ret(
recv(request->sockfd(), buf.data(), buf.size(), request->flags()));
+ if (response->ret() >= 0) {
+ response->set_buf(buf.data(), response->ret());
+ }
response->set_errno_(errno);
- response->set_buf(buf.data(), response->ret());
return ::grpc::Status::OK;
}
};
diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto
index 4035e1ee6..ab5ba1c85 100644
--- a/test/packetimpact/proto/posix_server.proto
+++ b/test/packetimpact/proto/posix_server.proto
@@ -73,6 +73,16 @@ message CloseResponse {
int32 errno_ = 2; // "errno" may fail to compile in c++.
}
+message ConnectRequest {
+ int32 sockfd = 1;
+ Sockaddr addr = 2;
+}
+
+message ConnectResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
message GetSockNameRequest {
int32 sockfd = 1;
}
@@ -83,6 +93,43 @@ message GetSockNameResponse {
Sockaddr addr = 3;
}
+message GetSockOptRequest {
+ int32 sockfd = 1;
+ int32 level = 2;
+ int32 optname = 3;
+ int32 optlen = 4;
+}
+
+message GetSockOptResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ bytes optval = 3;
+}
+
+message GetSockOptIntRequest {
+ int32 sockfd = 1;
+ int32 level = 2;
+ int32 optname = 3;
+}
+
+message GetSockOptIntResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ int32 intval = 3;
+}
+
+message GetSockOptTimevalRequest {
+ int32 sockfd = 1;
+ int32 level = 2;
+ int32 optname = 3;
+}
+
+message GetSockOptTimevalResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ Timeval timeval = 3;
+}
+
message ListenRequest {
int32 sockfd = 1;
int32 backlog = 2;
@@ -104,6 +151,18 @@ message SendResponse {
int32 errno_ = 2;
}
+message SendToRequest {
+ int32 sockfd = 1;
+ bytes buf = 2;
+ int32 flags = 3;
+ Sockaddr dest_addr = 4;
+}
+
+message SendToResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
message SetSockOptRequest {
int32 sockfd = 1;
int32 level = 2;
@@ -170,12 +229,26 @@ service Posix {
rpc Bind(BindRequest) returns (BindResponse);
// Call close() on the DUT.
rpc Close(CloseRequest) returns (CloseResponse);
+ // Call connect() on the DUT.
+ rpc Connect(ConnectRequest) returns (ConnectResponse);
// Call getsockname() on the DUT.
rpc GetSockName(GetSockNameRequest) returns (GetSockNameResponse);
+ // Call getsockopt() on the DUT. You should prefer one of the other
+ // GetSockOpt* functions with a more structured optval or else you may get the
+ // encoding wrong, such as making a bad assumption about the server's word
+ // sizes or endianness.
+ rpc GetSockOpt(GetSockOptRequest) returns (GetSockOptResponse);
+ // Call getsockopt() on the DUT with an int optval.
+ rpc GetSockOptInt(GetSockOptIntRequest) returns (GetSockOptIntResponse);
+ // Call getsockopt() on the DUT with a Timeval optval.
+ rpc GetSockOptTimeval(GetSockOptTimevalRequest)
+ returns (GetSockOptTimevalResponse);
// Call listen() on the DUT.
rpc Listen(ListenRequest) returns (ListenResponse);
// Call send() on the DUT.
rpc Send(SendRequest) returns (SendResponse);
+ // Call sendto() on the DUT.
+ rpc SendTo(SendToRequest) returns (SendToResponse);
// Call setsockopt() on the DUT. You should prefer one of the other
// SetSockOpt* functions with a more structured optval or else you may get the
// encoding wrong, such as making a bad assumption about the server's word
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index 2280bd4ee..56ac3fa54 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -39,20 +39,28 @@ var remoteIPv6 = flag.String("remote_ipv6", "", "remote IPv6 address for test pa
var localMAC = flag.String("local_mac", "", "local mac address for test packets")
var remoteMAC = flag.String("remote_mac", "", "remote mac address for test packets")
-// pickPort makes a new socket and returns the socket FD and port. The domain
-// should be AF_INET or AF_INET6. The caller must close the FD when done with
+func portFromSockaddr(sa unix.Sockaddr) (uint16, error) {
+ switch sa := sa.(type) {
+ case *unix.SockaddrInet4:
+ return uint16(sa.Port), nil
+ case *unix.SockaddrInet6:
+ return uint16(sa.Port), nil
+ }
+ return 0, fmt.Errorf("sockaddr type %T does not contain port", sa)
+}
+
+// pickPort makes a new socket and returns the socket FD and port. The domain should be AF_INET or AF_INET6. The caller must close the FD when done with
// the port if there is no error.
-func pickPort(domain, typ int) (fd int, port uint16, err error) {
+func pickPort(domain, typ int) (fd int, sa unix.Sockaddr, err error) {
fd, err = unix.Socket(domain, typ, 0)
if err != nil {
- return -1, 0, err
+ return -1, nil, err
}
defer func() {
if err != nil {
err = multierr.Append(err, unix.Close(fd))
}
}()
- var sa unix.Sockaddr
switch domain {
case unix.AF_INET:
var sa4 unix.SockaddrInet4
@@ -63,31 +71,16 @@ func pickPort(domain, typ int) (fd int, port uint16, err error) {
copy(sa6.Addr[:], net.ParseIP(*localIPv6).To16())
sa = &sa6
default:
- return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
+ return -1, nil, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
}
if err = unix.Bind(fd, sa); err != nil {
- return -1, 0, err
+ return -1, nil, err
}
- newSockAddr, err := unix.Getsockname(fd)
+ sa, err = unix.Getsockname(fd)
if err != nil {
- return -1, 0, err
- }
- switch domain {
- case unix.AF_INET:
- newSockAddrInet4, ok := newSockAddr.(*unix.SockaddrInet4)
- if !ok {
- return -1, 0, fmt.Errorf("can't cast Getsockname result %T to SockaddrInet4", newSockAddr)
- }
- return fd, uint16(newSockAddrInet4.Port), nil
- case unix.AF_INET6:
- newSockAddrInet6, ok := newSockAddr.(*unix.SockaddrInet6)
- if !ok {
- return -1, 0, fmt.Errorf("can't cast Getsockname result %T to SockaddrInet6", newSockAddr)
- }
- return fd, uint16(newSockAddrInet6.Port), nil
- default:
- return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
+ return -1, nil, err
}
+ return fd, sa, nil
}
// layerState stores the state of a layer of a connection.
@@ -282,7 +275,11 @@ func SeqNumValue(v seqnum.Value) *seqnum.Value {
// newTCPState creates a new TCPState.
func newTCPState(domain int, out, in TCP) (*tcpState, error) {
- portPickerFD, localPort, err := pickPort(domain, unix.SOCK_STREAM)
+ portPickerFD, localAddr, err := pickPort(domain, unix.SOCK_STREAM)
+ if err != nil {
+ return nil, err
+ }
+ localPort, err := portFromSockaddr(localAddr)
if err != nil {
return nil, err
}
@@ -385,10 +382,14 @@ type udpState struct {
var _ layerState = (*udpState)(nil)
// newUDPState creates a new udpState.
-func newUDPState(domain int, out, in UDP) (*udpState, error) {
- portPickerFD, localPort, err := pickPort(domain, unix.SOCK_DGRAM)
+func newUDPState(domain int, out, in UDP) (*udpState, unix.Sockaddr, error) {
+ portPickerFD, localAddr, err := pickPort(domain, unix.SOCK_DGRAM)
if err != nil {
- return nil, err
+ return nil, nil, err
+ }
+ localPort, err := portFromSockaddr(localAddr)
+ if err != nil {
+ return nil, nil, err
}
s := udpState{
out: UDP{SrcPort: &localPort},
@@ -396,12 +397,12 @@ func newUDPState(domain int, out, in UDP) (*udpState, error) {
portPickerFD: portPickerFD,
}
if err := s.out.merge(&out); err != nil {
- return nil, err
+ return nil, nil, err
}
if err := s.in.merge(&in); err != nil {
- return nil, err
+ return nil, nil, err
}
- return &s, nil
+ return &s, localAddr, nil
}
func (s *udpState) outgoing() Layer {
@@ -436,6 +437,7 @@ type Connection struct {
layerStates []layerState
injector Injector
sniffer Sniffer
+ localAddr unix.Sockaddr
t *testing.T
}
@@ -499,7 +501,7 @@ func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Laye
func (conn *Connection) SendFrame(frame Layers) {
outBytes, err := frame.ToBytes()
if err != nil {
- conn.t.Fatalf("can't build outgoing TCP packet: %s", err)
+ conn.t.Fatalf("can't build outgoing packet: %s", err)
}
conn.injector.Send(outBytes)
@@ -545,8 +547,9 @@ func (e *layersError) Error() string {
return e.got.diff(e.want)
}
-// Expect a frame with the final layerStates layer matching the provided Layer
-// within the timeout specified. If it doesn't arrive in time, it returns nil.
+// Expect expects a frame with the final layerStates layer matching the
+// provided Layer within the timeout specified. If it doesn't arrive in time,
+// an error is returned.
func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) {
// Make a frame that will ignore all but the final layer.
layers := make([]Layer, len(conn.layerStates))
@@ -671,8 +674,8 @@ func (conn *TCPIPv4) Close() {
(*Connection)(conn).Close()
}
-// Expect a frame with the TCP layer matching the provided TCP within the
-// timeout specified. If it doesn't arrive in time, it returns nil.
+// Expect expects a frame with the TCP layer matching the provided TCP within
+// the timeout specified. If it doesn't arrive in time, an error is returned.
func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) {
layer, err := (*Connection)(conn).Expect(&tcp, timeout)
if layer == nil {
@@ -756,7 +759,7 @@ func (conn *IPv6Conn) Close() {
}
// ExpectFrame expects a frame that matches the provided Layers within the
-// timeout specified. If it doesn't arrive in time, it returns nil.
+// timeout specified. If it doesn't arrive in time, an error is returned.
func (conn *IPv6Conn) ExpectFrame(frame Layers, timeout time.Duration) (Layers, error) {
return (*Connection)(conn).ExpectFrame(frame, timeout)
}
@@ -780,7 +783,7 @@ func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
if err != nil {
t.Fatalf("can't make ipv4State: %s", err)
}
- tcpState, err := newUDPState(unix.AF_INET, outgoingUDP, incomingUDP)
+ udpState, localAddr, err := newUDPState(unix.AF_INET, outgoingUDP, incomingUDP)
if err != nil {
t.Fatalf("can't make udpState: %s", err)
}
@@ -794,24 +797,61 @@ func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
}
return UDPIPv4{
- layerStates: []layerState{etherState, ipv4State, tcpState},
+ layerStates: []layerState{etherState, ipv4State, udpState},
injector: injector,
sniffer: sniffer,
+ localAddr: localAddr,
t: t,
}
}
+// LocalAddr gets the local socket address of this connection.
+func (conn *UDPIPv4) LocalAddr() unix.Sockaddr {
+ return conn.localAddr
+}
+
// CreateFrame builds a frame for the connection with layer overriding defaults
// of the innermost layer and additionalLayers added after it.
func (conn *UDPIPv4) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
return (*Connection)(conn).CreateFrame(layer, additionalLayers...)
}
+// Send a packet with reasonable defaults. Potentially override the UDP layer in
+// the connection with the provided layer and add additionLayers.
+func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) {
+ (*Connection)(conn).Send(&udp, additionalLayers...)
+}
+
// SendFrame sends a frame on the wire and updates the state of all layers.
func (conn *UDPIPv4) SendFrame(frame Layers) {
(*Connection)(conn).SendFrame(frame)
}
+// SendIP sends a packet with additionalLayers following the IP layer in the
+// connection.
+func (conn *UDPIPv4) SendIP(additionalLayers ...Layer) {
+ var layersToSend Layers
+ for _, s := range conn.layerStates[:len(conn.layerStates)-1] {
+ layersToSend = append(layersToSend, s.outgoing())
+ }
+ layersToSend = append(layersToSend, additionalLayers...)
+ conn.SendFrame(layersToSend)
+}
+
+// Expect expects a frame with the UDP layer matching the provided UDP within
+// the timeout specified. If it doesn't arrive in time, an error is returned.
+func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) (*UDP, error) {
+ layer, err := (*Connection)(conn).Expect(&udp, timeout)
+ if layer == nil {
+ return nil, err
+ }
+ gotUDP, ok := layer.(*UDP)
+ if !ok {
+ conn.t.Fatalf("expected %s to be UDP", layer)
+ }
+ return gotUDP, err
+}
+
// Close frees associated resources held by the UDPIPv4 connection.
func (conn *UDPIPv4) Close() {
(*Connection)(conn).Close()
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
index 3f340c6bc..87eeeeb88 100644
--- a/test/packetimpact/testbench/dut.go
+++ b/test/packetimpact/testbench/dut.go
@@ -237,6 +237,33 @@ func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) {
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
+// Connect calls connect on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use ConnectWithErrno.
+func (dut *DUT) Connect(fd int32, sa unix.Sockaddr) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.ConnectWithErrno(ctx, fd, sa)
+ if ret != 0 {
+ dut.t.Fatalf("failed to connect socket: %s", err)
+ }
+}
+
+// ConnectWithErrno calls bind on the DUT.
+func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) {
+ dut.t.Helper()
+ req := pb.ConnectRequest{
+ Sockfd: fd,
+ Addr: dut.sockaddrToProto(sa),
+ }
+ resp, err := dut.posixServer.Connect(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Connect: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
// GetSockName calls getsockname on the DUT and causes a fatal test failure if
// it doesn't succeed. If more control over the timeout or error handling is
// needed, use GetSockNameWithErrno.
@@ -264,6 +291,102 @@ func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32,
return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
}
+// GetSockOpt calls getsockopt on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use GetSockOptWithErrno. Because endianess and the width of values
+// might differ between the testbench and DUT architectures, prefer to use a
+// more specific GetSockOptXxx function.
+func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, optval, err := dut.GetSockOptWithErrno(ctx, sockfd, level, optname, optlen)
+ if ret != 0 {
+ dut.t.Fatalf("failed to GetSockOpt: %s", err)
+ }
+ return optval
+}
+
+// GetSockOptWithErrno calls getsockopt on the DUT. Because endianess and the
+// width of values might differ between the testbench and DUT architectures,
+// prefer to use a more specific GetSockOptXxxWithErrno function.
+func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname, optlen int32) (int32, []byte, error) {
+ dut.t.Helper()
+ req := pb.GetSockOptRequest{
+ Sockfd: sockfd,
+ Level: level,
+ Optname: optname,
+ Optlen: optlen,
+ }
+ resp, err := dut.posixServer.GetSockOpt(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call GetSockOpt: %s", err)
+ }
+ return resp.GetRet(), resp.GetOptval(), syscall.Errno(resp.GetErrno_())
+}
+
+// GetSockOptInt calls getsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the int optval or error handling
+// is needed, use GetSockOptIntWithErrno.
+func (dut *DUT) GetSockOptInt(sockfd, level, optname int32) int32 {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, intval, err := dut.GetSockOptIntWithErrno(ctx, sockfd, level, optname)
+ if ret != 0 {
+ dut.t.Fatalf("failed to GetSockOptInt: %s", err)
+ }
+ return intval
+}
+
+// GetSockOptIntWithErrno calls getsockopt with an integer optval.
+func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, int32, error) {
+ dut.t.Helper()
+ req := pb.GetSockOptIntRequest{
+ Sockfd: sockfd,
+ Level: level,
+ Optname: optname,
+ }
+ resp, err := dut.posixServer.GetSockOptInt(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call GetSockOptInt: %s", err)
+ }
+ return resp.GetRet(), resp.GetIntval(), syscall.Errno(resp.GetErrno_())
+}
+
+// GetSockOptTimeval calls getsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the timeout or error handling is
+// needed, use GetSockOptTimevalWithErrno.
+func (dut *DUT) GetSockOptTimeval(sockfd, level, optname int32) unix.Timeval {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, sockfd, level, optname)
+ if ret != 0 {
+ dut.t.Fatalf("failed to GetSockOptTimeval: %s", err)
+ }
+ return timeval
+}
+
+// GetSockOptTimevalWithErrno calls getsockopt and returns a timeval.
+func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, unix.Timeval, error) {
+ dut.t.Helper()
+ req := pb.GetSockOptTimevalRequest{
+ Sockfd: sockfd,
+ Level: level,
+ Optname: optname,
+ }
+ resp, err := dut.posixServer.GetSockOptTimeval(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call GetSockOptTimeval: %s", err)
+ }
+ timeval := unix.Timeval{
+ Sec: resp.GetTimeval().Seconds,
+ Usec: resp.GetTimeval().Microseconds,
+ }
+ return resp.GetRet(), timeval, syscall.Errno(resp.GetErrno_())
+}
+
// Listen calls listen on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// ListenWithErrno.
@@ -320,6 +443,36 @@ func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, fla
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
+// SendTo calls sendto on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// SendToWithErrno.
+func (dut *DUT) SendTo(sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.SendToWithErrno(ctx, sockfd, buf, flags, destAddr)
+ if ret == -1 {
+ dut.t.Fatalf("failed to sendto: %s", err)
+ }
+ return ret
+}
+
+// SendToWithErrno calls sendto on the DUT.
+func (dut *DUT) SendToWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) {
+ dut.t.Helper()
+ req := pb.SendToRequest{
+ Sockfd: sockfd,
+ Buf: buf,
+ Flags: flags,
+ DestAddr: dut.sockaddrToProto(destAddr),
+ }
+ resp, err := dut.posixServer.SendTo(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("faled to call SendTo: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it
// doesn't succeed. If more control over the timeout or error handling is
// needed, use SetSockOptWithErrno. Because endianess and the width of values
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index 817f5c261..165f62d3b 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -58,7 +58,7 @@ type Layer interface {
next() Layer
// prev gets a pointer to the Layer encapsulating this one.
- prev() Layer
+ Prev() Layer
// setNext sets the pointer to the encapsulated Layer.
setNext(Layer)
@@ -80,7 +80,8 @@ func (lb *LayerBase) next() Layer {
return lb.nextLayer
}
-func (lb *LayerBase) prev() Layer {
+// Prev returns the previous layer.
+func (lb *LayerBase) Prev() Layer {
return lb.prevLayer
}
@@ -340,6 +341,8 @@ func (l *IPv4) ToBytes() ([]byte, error) {
fields.Protocol = uint8(header.TCPProtocolNumber)
case *UDP:
fields.Protocol = uint8(header.UDPProtocolNumber)
+ case *ICMPv4:
+ fields.Protocol = uint8(header.ICMPv4ProtocolNumber)
default:
// TODO(b/150301488): Support more protocols as needed.
return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n)
@@ -403,6 +406,8 @@ func parseIPv4(b []byte) (Layer, layerParser) {
nextParser = parseTCP
case header.UDPProtocolNumber:
nextParser = parseUDP
+ case header.ICMPv4ProtocolNumber:
+ nextParser = parseICMPv4
default:
// Assume that the rest is a payload.
nextParser = parsePayload
@@ -562,7 +567,7 @@ func (l *ICMPv6) ToBytes() ([]byte, error) {
if l.Checksum != nil {
h.SetChecksum(*l.Checksum)
} else {
- ipv6 := l.prev().(*IPv6)
+ ipv6 := l.Prev().(*IPv6)
h.SetChecksum(header.ICMPv6Checksum(h, *ipv6.SrcAddr, *ipv6.DstAddr, buffer.VectorisedView{}))
}
return h, nil
@@ -606,6 +611,72 @@ func (l *ICMPv6) merge(other Layer) error {
return mergeLayer(l, other)
}
+// ICMPv4Type is a helper routine that allocates a new header.ICMPv4Type value
+// to store t and returns a pointer to it.
+func ICMPv4Type(t header.ICMPv4Type) *header.ICMPv4Type {
+ return &t
+}
+
+// ICMPv4 can construct and match an ICMPv4 encapsulation.
+type ICMPv4 struct {
+ LayerBase
+ Type *header.ICMPv4Type
+ Code *uint8
+ Checksum *uint16
+}
+
+func (l *ICMPv4) String() string {
+ return stringLayer(l)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *ICMPv4) ToBytes() ([]byte, error) {
+ b := make([]byte, header.ICMPv4MinimumSize)
+ h := header.ICMPv4(b)
+ if l.Type != nil {
+ h.SetType(*l.Type)
+ }
+ if l.Code != nil {
+ h.SetCode(byte(*l.Code))
+ }
+ if l.Checksum != nil {
+ h.SetChecksum(*l.Checksum)
+ return h, nil
+ }
+ payload, err := payload(l)
+ if err != nil {
+ return nil, err
+ }
+ h.SetChecksum(header.ICMPv4Checksum(h, payload))
+ return h, nil
+}
+
+// parseICMPv4 parses the bytes as an ICMPv4 header, returning a Layer and a
+// parser for the encapsulated payload.
+func parseICMPv4(b []byte) (Layer, layerParser) {
+ h := header.ICMPv4(b)
+ icmpv4 := ICMPv4{
+ Type: ICMPv4Type(h.Type()),
+ Code: Uint8(h.Code()),
+ Checksum: Uint16(h.Checksum()),
+ }
+ return &icmpv4, parsePayload
+}
+
+func (l *ICMPv4) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *ICMPv4) length() int {
+ return header.ICMPv4MinimumSize
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *ICMPv4) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
// TCP can construct and match a TCP encapsulation.
type TCP struct {
LayerBase
@@ -676,25 +747,34 @@ func totalLength(l Layer) int {
return totalLength
}
+// payload returns a buffer.VectorisedView of l's payload.
+func payload(l Layer) (buffer.VectorisedView, error) {
+ var payloadBytes buffer.VectorisedView
+ for current := l.next(); current != nil; current = current.next() {
+ payload, err := current.ToBytes()
+ if err != nil {
+ return buffer.VectorisedView{}, fmt.Errorf("can't get bytes for next header: %s", payload)
+ }
+ payloadBytes.AppendView(payload)
+ }
+ return payloadBytes, nil
+}
+
// layerChecksum calculates the checksum of the Layer header, including the
-// peusdeochecksum of the layer before it and all the bytes after it..
+// peusdeochecksum of the layer before it and all the bytes after it.
func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) {
totalLength := uint16(totalLength(l))
var xsum uint16
- switch s := l.prev().(type) {
+ switch s := l.Prev().(type) {
case *IPv4:
xsum = header.PseudoHeaderChecksum(protoNumber, *s.SrcAddr, *s.DstAddr, totalLength)
default:
// TODO(b/150301488): Support more protocols, like IPv6.
return 0, fmt.Errorf("can't get src and dst addr from previous layer: %#v", s)
}
- var payloadBytes buffer.VectorisedView
- for current := l.next(); current != nil; current = current.next() {
- payload, err := current.ToBytes()
- if err != nil {
- return 0, fmt.Errorf("can't get bytes for next header: %s", payload)
- }
- payloadBytes.AppendView(payload)
+ payloadBytes, err := payload(l)
+ if err != nil {
+ return 0, err
}
xsum = header.ChecksumVV(payloadBytes, xsum)
return xsum, nil
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index 42f87e3f3..6beccbfd0 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -29,6 +29,19 @@ packetimpact_go_test(
)
packetimpact_go_test(
+ name = "udp_icmp_error_propagation",
+ srcs = ["udp_icmp_error_propagation_test.go"],
+ # TODO(b/153926291): Fix netstack then remove the line below.
+ netstack = False,
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
name = "tcp_window_shrink",
srcs = ["tcp_window_shrink_test.go"],
# TODO(b/153202472): Fix netstack then remove the line below.
diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
new file mode 100644
index 000000000..9e4810842
--- /dev/null
+++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
@@ -0,0 +1,209 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp_icmp_error_propagation_test
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "syscall"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+type connected bool
+
+func (c connected) String() string {
+ if c {
+ return "Connected"
+ }
+ return "Connectionless"
+}
+
+type icmpError int
+
+const (
+ portUnreachable icmpError = iota
+ timeToLiveExceeded
+)
+
+func (e icmpError) String() string {
+ switch e {
+ case portUnreachable:
+ return "PortUnreachable"
+ case timeToLiveExceeded:
+ return "TimeToLiveExpired"
+ }
+ return "Unknown ICMP error"
+}
+
+func (e icmpError) ToICMPv4() *tb.ICMPv4 {
+ switch e {
+ case portUnreachable:
+ return &tb.ICMPv4{Type: tb.ICMPv4Type(header.ICMPv4DstUnreachable), Code: tb.Uint8(header.ICMPv4PortUnreachable)}
+ case timeToLiveExceeded:
+ return &tb.ICMPv4{Type: tb.ICMPv4Type(header.ICMPv4TimeExceeded), Code: tb.Uint8(header.ICMPv4TTLExceeded)}
+ }
+ return nil
+}
+
+type errorDetectionFunc func(context.Context, *tb.DUT, *tb.UDPIPv4, int32, syscall.Errno) error
+
+// testRecv tests observing the ICMP error through the recv syscall.
+// A packet is sent to the DUT, and if wantErrno is non-zero, then the first
+// recv should fail and the second should succeed. Otherwise if wantErrno is
+// zero then the first recv should succeed immediately.
+func testRecv(ctx context.Context, dut *tb.DUT, conn *tb.UDPIPv4, remoteFD int32, wantErrno syscall.Errno) error {
+ conn.Send(tb.UDP{})
+
+ if wantErrno != syscall.Errno(0) {
+ ctx, cancel := context.WithTimeout(ctx, time.Second)
+ defer cancel()
+ ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0)
+ if ret != -1 {
+ return fmt.Errorf("recv after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno)
+ }
+ if err != wantErrno {
+ return fmt.Errorf("recv after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, wantErrno)
+ }
+ }
+
+ dut.Recv(remoteFD, 100, 0)
+ return nil
+}
+
+// testSendTo tests observing the ICMP error through the send syscall.
+// If wantErrno is non-zero, the first send should fail and a subsequent send
+// should suceed; while if wantErrno is zero then the first send should just
+// succeed.
+func testSendTo(ctx context.Context, dut *tb.DUT, conn *tb.UDPIPv4, remoteFD int32, wantErrno syscall.Errno) error {
+ if wantErrno != syscall.Errno(0) {
+ ctx, cancel := context.WithTimeout(ctx, time.Second)
+ defer cancel()
+ ret, err := dut.SendToWithErrno(ctx, remoteFD, nil, 0, conn.LocalAddr())
+
+ if ret != -1 {
+ return fmt.Errorf("sendto after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno)
+ }
+ if err != wantErrno {
+ return fmt.Errorf("sendto after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, wantErrno)
+ }
+ }
+
+ dut.SendTo(remoteFD, nil, 0, conn.LocalAddr())
+ if _, err := conn.Expect(tb.UDP{}, time.Second); err != nil {
+ return fmt.Errorf("did not receive UDP packet as expected: %s", err)
+ }
+ return nil
+}
+
+func testSockOpt(_ context.Context, dut *tb.DUT, conn *tb.UDPIPv4, remoteFD int32, wantErrno syscall.Errno) error {
+ errno := syscall.Errno(dut.GetSockOptInt(remoteFD, unix.SOL_SOCKET, unix.SO_ERROR))
+ if errno != wantErrno {
+ return fmt.Errorf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, wantErrno)
+ }
+
+ // Check that after clearing socket error, sending doesn't fail.
+ dut.SendTo(remoteFD, nil, 0, conn.LocalAddr())
+ if _, err := conn.Expect(tb.UDP{}, time.Second); err != nil {
+ return fmt.Errorf("did not receive UDP packet as expected: %s", err)
+ }
+ return nil
+}
+
+type testParameters struct {
+ connected connected
+ icmpErr icmpError
+ wantErrno syscall.Errno
+ f errorDetectionFunc
+ fName string
+}
+
+// TestUDPICMPErrorPropagation tests that ICMP PortUnreachable error messages
+// destined for a "connected" UDP socket are observable on said socket by:
+// 1. causing the next send to fail with ECONNREFUSED,
+// 2. causing the next recv to fail with ECONNREFUSED, or
+// 3. returning ECONNREFUSED through the SO_ERROR socket option.
+func TestUDPICMPErrorPropagation(t *testing.T) {
+ var testCases []testParameters
+ for _, c := range []connected{true, false} {
+ for _, i := range []icmpError{portUnreachable, timeToLiveExceeded} {
+ e := syscall.Errno(0)
+ if c && i == portUnreachable {
+ e = unix.ECONNREFUSED
+ }
+ for _, f := range []struct {
+ name string
+ f errorDetectionFunc
+ }{
+ {"SendTo", testSendTo},
+ {"Recv", testRecv},
+ {"SockOpt", testSockOpt},
+ } {
+ testCases = append(testCases, testParameters{c, i, e, f.f, f.name})
+ }
+ }
+ }
+
+ for _, tt := range testCases {
+ t.Run(fmt.Sprintf("%s/%s/%s", tt.connected, tt.icmpErr, tt.fName), func(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+
+ remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
+ defer dut.Close(remoteFD)
+
+ conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ if tt.connected {
+ dut.Connect(remoteFD, conn.LocalAddr())
+ }
+
+ dut.SendTo(remoteFD, nil, 0, conn.LocalAddr())
+ udp, err := conn.Expect(tb.UDP{}, time.Second)
+ if err != nil {
+ t.Fatalf("did not receive message from DUT: %s", err)
+ }
+
+ if tt.icmpErr == timeToLiveExceeded {
+ ip, ok := udp.Prev().(*tb.IPv4)
+ if !ok {
+ t.Fatalf("expected %s to be IPv4", udp.Prev())
+ }
+ *ip.TTL = 1
+ // Let serialization recalculate the checksum since we set the
+ // TTL to 1.
+ ip.Checksum = nil
+
+ // Note that the ICMP payload is valid in this case because the UDP
+ // payload is empty. If the UDP payload were not empty, the packet
+ // length during serialization may not be calculated correctly,
+ // resulting in a mal-formed packet.
+ conn.SendIP(tt.icmpErr.ToICMPv4(), ip, udp)
+ } else {
+ conn.SendIP(tt.icmpErr.ToICMPv4(), udp.Prev(), udp)
+ }
+
+ if err := tt.f(context.Background(), &dut, &conn, remoteFD, tt.wantErrno); err != nil {
+ t.Fatal(err)
+ }
+ })
+ }
+}