summaryrefslogtreecommitdiffhomepage
path: root/test/packetimpact/testbench
diff options
context:
space:
mode:
Diffstat (limited to 'test/packetimpact/testbench')
-rw-r--r--test/packetimpact/testbench/BUILD5
-rw-r--r--test/packetimpact/testbench/connections.go612
-rw-r--r--test/packetimpact/testbench/dut.go390
-rw-r--r--test/packetimpact/testbench/layers.go329
-rw-r--r--test/packetimpact/testbench/layers_test.go221
-rw-r--r--test/packetimpact/testbench/rawsockets.go47
-rw-r--r--test/packetimpact/testbench/testbench.go81
7 files changed, 1330 insertions, 355 deletions
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
index 682933067..5a0ee1367 100644
--- a/test/packetimpact/testbench/BUILD
+++ b/test/packetimpact/testbench/BUILD
@@ -21,9 +21,10 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
"//pkg/usermem",
+ "//test/packetimpact/netdevs",
"//test/packetimpact/proto:posix_server_go_proto",
- "@com_github_google_go-cmp//cmp:go_default_library",
- "@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
"@com_github_mohae_deepcopy//:go_default_library",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//keepalive:go_default_library",
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index 463fd0556..3af5f83fd 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -41,38 +41,46 @@ func portFromSockaddr(sa unix.Sockaddr) (uint16, error) {
return 0, fmt.Errorf("sockaddr type %T does not contain port", sa)
}
-// pickPort makes a new socket and returns the socket FD and port. The domain should be AF_INET or AF_INET6. The caller must close the FD when done with
+// pickPort makes a new socket and returns the socket FD and port. The domain
+// should be AF_INET or AF_INET6. The caller must close the FD when done with
// the port if there is no error.
-func pickPort(domain, typ int) (fd int, sa unix.Sockaddr, err error) {
+func pickPort(domain, typ int) (fd int, port uint16, err error) {
fd, err = unix.Socket(domain, typ, 0)
if err != nil {
- return -1, nil, err
+ return -1, 0, fmt.Errorf("creating socket: %w", err)
}
defer func() {
if err != nil {
- err = multierr.Append(err, unix.Close(fd))
+ if cerr := unix.Close(fd); cerr != nil {
+ err = multierr.Append(err, fmt.Errorf("failed to close socket %d: %w", fd, cerr))
+ }
}
}()
+ var sa unix.Sockaddr
switch domain {
case unix.AF_INET:
var sa4 unix.SockaddrInet4
copy(sa4.Addr[:], net.ParseIP(LocalIPv4).To4())
sa = &sa4
case unix.AF_INET6:
- var sa6 unix.SockaddrInet6
+ sa6 := unix.SockaddrInet6{ZoneId: uint32(LocalInterfaceID)}
copy(sa6.Addr[:], net.ParseIP(LocalIPv6).To16())
sa = &sa6
default:
- return -1, nil, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
+ return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
}
if err = unix.Bind(fd, sa); err != nil {
- return -1, nil, err
+ return -1, 0, fmt.Errorf("binding to %+v: %w", sa, err)
}
sa, err = unix.Getsockname(fd)
if err != nil {
- return -1, nil, err
+ return -1, 0, fmt.Errorf("Getsocketname(%d): %w", fd, err)
+ }
+ port, err = portFromSockaddr(sa)
+ if err != nil {
+ return -1, 0, fmt.Errorf("extracting port from socket address %+v: %w", sa, err)
}
- return fd, sa, nil
+ return fd, port, nil
}
// layerState stores the state of a layer of a connection.
@@ -114,12 +122,12 @@ var _ layerState = (*etherState)(nil)
func newEtherState(out, in Ether) (*etherState, error) {
lMAC, err := tcpip.ParseMACAddress(LocalMAC)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("parsing local MAC: %q: %w", LocalMAC, err)
}
rMAC, err := tcpip.ParseMACAddress(RemoteMAC)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("parsing remote MAC: %q: %w", RemoteMAC, err)
}
s := etherState{
out: Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
@@ -267,11 +275,7 @@ func SeqNumValue(v seqnum.Value) *seqnum.Value {
// newTCPState creates a new TCPState.
func newTCPState(domain int, out, in TCP) (*tcpState, error) {
- portPickerFD, localAddr, err := pickPort(domain, unix.SOCK_STREAM)
- if err != nil {
- return nil, err
- }
- localPort, err := portFromSockaddr(localAddr)
+ portPickerFD, localPort, err := pickPort(domain, unix.SOCK_STREAM)
if err != nil {
return nil, err
}
@@ -374,14 +378,10 @@ type udpState struct {
var _ layerState = (*udpState)(nil)
// newUDPState creates a new udpState.
-func newUDPState(domain int, out, in UDP) (*udpState, unix.Sockaddr, error) {
- portPickerFD, localAddr, err := pickPort(domain, unix.SOCK_DGRAM)
+func newUDPState(domain int, out, in UDP) (*udpState, error) {
+ portPickerFD, localPort, err := pickPort(domain, unix.SOCK_DGRAM)
if err != nil {
- return nil, nil, err
- }
- localPort, err := portFromSockaddr(localAddr)
- if err != nil {
- return nil, nil, err
+ return nil, fmt.Errorf("picking port: %w", err)
}
s := udpState{
out: UDP{SrcPort: &localPort},
@@ -389,12 +389,12 @@ func newUDPState(domain int, out, in UDP) (*udpState, unix.Sockaddr, error) {
portPickerFD: portPickerFD,
}
if err := s.out.merge(&out); err != nil {
- return nil, nil, err
+ return nil, err
}
if err := s.in.merge(&in); err != nil {
- return nil, nil, err
+ return nil, err
}
- return &s, localAddr, nil
+ return &s, nil
}
func (s *udpState) outgoing() Layer {
@@ -429,8 +429,6 @@ type Connection struct {
layerStates []layerState
injector Injector
sniffer Sniffer
- localAddr unix.Sockaddr
- t *testing.T
}
// Returns the default incoming frame against which to match. If received is
@@ -463,7 +461,9 @@ func (conn *Connection) match(override, received Layers) bool {
}
// Close frees associated resources held by the Connection.
-func (conn *Connection) Close() {
+func (conn *Connection) Close(t *testing.T) {
+ t.Helper()
+
errs := multierr.Combine(conn.sniffer.close(), conn.injector.close())
for _, s := range conn.layerStates {
if err := s.close(); err != nil {
@@ -471,31 +471,62 @@ func (conn *Connection) Close() {
}
}
if errs != nil {
- conn.t.Fatalf("unable to close %+v: %s", conn, errs)
+ t.Fatalf("unable to close %+v: %s", conn, errs)
}
}
-// CreateFrame builds a frame for the connection with layer overriding defaults
-// of the innermost layer and additionalLayers added after it.
-func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
+// CreateFrame builds a frame for the connection with defaults overriden
+// from the innermost layer out, and additionalLayers added after it.
+//
+// Note that overrideLayers can have a length that is less than the number
+// of layers in this connection, and in such cases the innermost layers are
+// overriden first. As an example, valid values of overrideLayers for a TCP-
+// over-IPv4-over-Ethernet connection are: nil, [TCP], [IPv4, TCP], and
+// [Ethernet, IPv4, TCP].
+func (conn *Connection) CreateFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) Layers {
+ t.Helper()
+
var layersToSend Layers
- for _, s := range conn.layerStates {
- layersToSend = append(layersToSend, s.outgoing())
- }
- if err := layersToSend[len(layersToSend)-1].merge(layer); err != nil {
- conn.t.Fatalf("can't merge %+v into %+v: %s", layer, layersToSend[len(layersToSend)-1], err)
+ for i, s := range conn.layerStates {
+ layer := s.outgoing()
+ // overrideLayers and conn.layerStates have their tails aligned, so
+ // to find the index we move backwards by the distance i is to the
+ // end.
+ if j := len(overrideLayers) - (len(conn.layerStates) - i); j >= 0 {
+ if err := layer.merge(overrideLayers[j]); err != nil {
+ t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err)
+ }
+ }
+ layersToSend = append(layersToSend, layer)
}
layersToSend = append(layersToSend, additionalLayers...)
return layersToSend
}
+// SendFrameStateless sends a frame without updating any of the layer states.
+//
+// This method is useful for sending out-of-band control messages such as
+// ICMP packets, where it would not make sense to update the transport layer's
+// state using the ICMP header.
+func (conn *Connection) SendFrameStateless(t *testing.T, frame Layers) {
+ t.Helper()
+
+ outBytes, err := frame.ToBytes()
+ if err != nil {
+ t.Fatalf("can't build outgoing packet: %s", err)
+ }
+ conn.injector.Send(t, outBytes)
+}
+
// SendFrame sends a frame on the wire and updates the state of all layers.
-func (conn *Connection) SendFrame(frame Layers) {
+func (conn *Connection) SendFrame(t *testing.T, frame Layers) {
+ t.Helper()
+
outBytes, err := frame.ToBytes()
if err != nil {
- conn.t.Fatalf("can't build outgoing packet: %s", err)
+ t.Fatalf("can't build outgoing packet: %s", err)
}
- conn.injector.Send(outBytes)
+ conn.injector.Send(t, outBytes)
// frame might have nil values where the caller wanted to use default values.
// sentFrame will have no nil values in it because it comes from parsing the
@@ -504,25 +535,32 @@ func (conn *Connection) SendFrame(frame Layers) {
// Update the state of each layer based on what was sent.
for i, s := range conn.layerStates {
if err := s.sent(sentFrame[i]); err != nil {
- conn.t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err)
+ t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err)
}
}
}
-// Send a packet with reasonable defaults. Potentially override the final layer
-// in the connection with the provided layer and add additionLayers.
-func (conn *Connection) Send(layer Layer, additionalLayers ...Layer) {
- conn.SendFrame(conn.CreateFrame(layer, additionalLayers...))
+// send sends a packet, possibly with layers of this connection overridden and
+// additional layers added.
+//
+// Types defined with Connection as the underlying type should expose
+// type-safe versions of this method.
+func (conn *Connection) send(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) {
+ t.Helper()
+
+ conn.SendFrame(t, conn.CreateFrame(t, overrideLayers, additionalLayers...))
}
// recvFrame gets the next successfully parsed frame (of type Layers) within the
// timeout provided. If no parsable frame arrives before the timeout, it returns
// nil.
-func (conn *Connection) recvFrame(timeout time.Duration) Layers {
+func (conn *Connection) recvFrame(t *testing.T, timeout time.Duration) Layers {
+ t.Helper()
+
if timeout <= 0 {
return nil
}
- b := conn.sniffer.Recv(timeout)
+ b := conn.sniffer.Recv(t, timeout)
if b == nil {
return nil
}
@@ -542,32 +580,36 @@ func (e *layersError) Error() string {
// Expect expects a frame with the final layerStates layer matching the
// provided Layer within the timeout specified. If it doesn't arrive in time,
// an error is returned.
-func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) {
+func (conn *Connection) Expect(t *testing.T, layer Layer, timeout time.Duration) (Layer, error) {
+ t.Helper()
+
// Make a frame that will ignore all but the final layer.
layers := make([]Layer, len(conn.layerStates))
layers[len(layers)-1] = layer
- gotFrame, err := conn.ExpectFrame(layers, timeout)
+ gotFrame, err := conn.ExpectFrame(t, layers, timeout)
if err != nil {
return nil, err
}
if len(conn.layerStates)-1 < len(gotFrame) {
return gotFrame[len(conn.layerStates)-1], nil
}
- conn.t.Fatal("the received frame should be at least as long as the expected layers")
+ t.Fatalf("the received frame should be at least as long as the expected layers, got %d layers, want at least %d layers, got frame: %#v", len(gotFrame), len(conn.layerStates), gotFrame)
panic("unreachable")
}
// ExpectFrame expects a frame that matches the provided Layers within the
// timeout specified. If one arrives in time, the Layers is returned without an
// error. If it doesn't arrive in time, it returns nil and error is non-nil.
-func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) {
+func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
deadline := time.Now().Add(timeout)
var errs error
for {
var gotLayers Layers
if timeout = time.Until(deadline); timeout > 0 {
- gotLayers = conn.recvFrame(timeout)
+ gotLayers = conn.recvFrame(t, timeout)
}
if gotLayers == nil {
if errs == nil {
@@ -578,7 +620,7 @@ func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layer
if conn.match(layers, gotLayers) {
for i, s := range conn.layerStates {
if err := s.received(gotLayers[i]); err != nil {
- conn.t.Fatal(err)
+ t.Fatalf("failed to update test connection's layer states based on received frame: %s", err)
}
}
return gotLayers, nil
@@ -589,8 +631,10 @@ func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layer
// Drain drains the sniffer's receive buffer by receiving packets until there's
// nothing else to receive.
-func (conn *Connection) Drain() {
- conn.sniffer.Drain()
+func (conn *Connection) Drain(t *testing.T) {
+ t.Helper()
+
+ conn.sniffer.Drain(t)
}
// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection.
@@ -598,6 +642,8 @@ type TCPIPv4 Connection
// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
+ t.Helper()
+
etherState, err := newEtherState(Ether{}, Ether{})
if err != nil {
t.Fatalf("can't make etherState: %s", err)
@@ -623,84 +669,174 @@ func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
layerStates: []layerState{etherState, ipv4State, tcpState},
injector: injector,
sniffer: sniffer,
- t: t,
}
}
-// Handshake performs a TCP 3-way handshake. The input Connection should have a
+// Connect performs a TCP 3-way handshake. The input Connection should have a
// final TCP Layer.
-func (conn *TCPIPv4) Handshake() {
+func (conn *TCPIPv4) Connect(t *testing.T) {
+ t.Helper()
+
// Send the SYN.
- conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)})
+ conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn)})
// Wait for the SYN-ACK.
- synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
- if synAck == nil {
- conn.t.Fatalf("didn't get synack during handshake: %s", err)
+ synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("didn't get synack during handshake: %s", err)
}
conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
// Send an ACK.
- conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)})
+ conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)})
+}
+
+// ConnectWithOptions performs a TCP 3-way handshake with given TCP options.
+// The input Connection should have a final TCP Layer.
+func (conn *TCPIPv4) ConnectWithOptions(t *testing.T, options []byte) {
+ t.Helper()
+
+ // Send the SYN.
+ conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn), Options: options})
+
+ // Wait for the SYN-ACK.
+ synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("didn't get synack during handshake: %s", err)
+ }
+ conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
+
+ // Send an ACK.
+ conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)})
}
// ExpectData is a convenient method that expects a Layer and the Layer after
// it. If it doens't arrive in time, it returns nil.
-func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+func (conn *TCPIPv4) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = tcp
+ if payload != nil {
+ expected = append(expected, payload)
+ }
+ return (*Connection)(conn).ExpectFrame(t, expected, timeout)
+}
+
+// ExpectNextData attempts to receive the next incoming segment for the
+// connection and expects that to match the given layers.
+//
+// It differs from ExpectData() in that here we are only interested in the next
+// received segment, while ExpectData() can receive multiple segments for the
+// connection until there is a match with given layers or a timeout.
+func (conn *TCPIPv4) ExpectNextData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
+ // Receive the first incoming TCP segment for this connection.
+ got, err := conn.ExpectData(t, &TCP{}, nil, timeout)
+ if err != nil {
+ return nil, err
+ }
+
expected := make([]Layer, len(conn.layerStates))
expected[len(expected)-1] = tcp
if payload != nil {
expected = append(expected, payload)
+ tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum(t)) - uint32(payload.Length()))
}
- return (*Connection)(conn).ExpectFrame(expected, timeout)
+ if !(*Connection)(conn).match(expected, got) {
+ return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got)
+ }
+ return got, nil
}
// Send a packet with reasonable defaults. Potentially override the TCP layer in
// the connection with the provided layer and add additionLayers.
-func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
- (*Connection)(conn).Send(&tcp, additionalLayers...)
+func (conn *TCPIPv4) Send(t *testing.T, tcp TCP, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&tcp}, additionalLayers...)
}
// Close frees associated resources held by the TCPIPv4 connection.
-func (conn *TCPIPv4) Close() {
- (*Connection)(conn).Close()
+func (conn *TCPIPv4) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(conn).Close(t)
}
// Expect expects a frame with the TCP layer matching the provided TCP within
// the timeout specified. If it doesn't arrive in time, an error is returned.
-func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) {
- layer, err := (*Connection)(conn).Expect(&tcp, timeout)
+func (conn *TCPIPv4) Expect(t *testing.T, tcp TCP, timeout time.Duration) (*TCP, error) {
+ t.Helper()
+
+ layer, err := (*Connection)(conn).Expect(t, &tcp, timeout)
if layer == nil {
return nil, err
}
gotTCP, ok := layer.(*TCP)
if !ok {
- conn.t.Fatalf("expected %s to be TCP", layer)
+ t.Fatalf("expected %s to be TCP", layer)
}
return gotTCP, err
}
-func (conn *TCPIPv4) state() *tcpState {
- state, ok := conn.layerStates[len(conn.layerStates)-1].(*tcpState)
+func (conn *TCPIPv4) tcpState(t *testing.T) *tcpState {
+ t.Helper()
+
+ state, ok := conn.layerStates[2].(*tcpState)
+ if !ok {
+ t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2])
+ }
+ return state
+}
+
+func (conn *TCPIPv4) ipv4State(t *testing.T) *ipv4State {
+ t.Helper()
+
+ state, ok := conn.layerStates[1].(*ipv4State)
if !ok {
- conn.t.Fatalf("expected final state of %v to be tcpState", conn.layerStates)
+ t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1])
}
return state
}
// RemoteSeqNum returns the next expected sequence number from the DUT.
-func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value {
- return conn.state().remoteSeqNum
+func (conn *TCPIPv4) RemoteSeqNum(t *testing.T) *seqnum.Value {
+ t.Helper()
+
+ return conn.tcpState(t).remoteSeqNum
}
// LocalSeqNum returns the next sequence number to send from the testbench.
-func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value {
- return conn.state().localSeqNum
+func (conn *TCPIPv4) LocalSeqNum(t *testing.T) *seqnum.Value {
+ t.Helper()
+
+ return conn.tcpState(t).localSeqNum
}
// SynAck returns the SynAck that was part of the handshake.
-func (conn *TCPIPv4) SynAck() *TCP {
- return conn.state().synAck
+func (conn *TCPIPv4) SynAck(t *testing.T) *TCP {
+ t.Helper()
+
+ return conn.tcpState(t).synAck
+}
+
+// LocalAddr gets the local socket address of this connection.
+func (conn *TCPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 {
+ t.Helper()
+
+ sa := &unix.SockaddrInet4{Port: int(*conn.tcpState(t).out.SrcPort)}
+ copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr)
+ return sa
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *TCPIPv4) Drain(t *testing.T) {
+ t.Helper()
+
+ conn.sniffer.Drain(t)
}
// IPv6Conn maintains the state for all the layers in a IPv6 connection.
@@ -708,6 +844,8 @@ type IPv6Conn Connection
// NewIPv6Conn creates a new IPv6Conn connection with reasonable defaults.
func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn {
+ t.Helper()
+
etherState, err := newEtherState(Ether{}, Ether{})
if err != nil {
t.Fatalf("can't make EtherState: %s", err)
@@ -730,36 +868,30 @@ func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn {
layerStates: []layerState{etherState, ipv6State},
injector: injector,
sniffer: sniffer,
- t: t,
}
}
-// SendFrame sends a frame on the wire and updates the state of all layers.
-func (conn *IPv6Conn) SendFrame(frame Layers) {
- (*Connection)(conn).SendFrame(frame)
-}
+// Send sends a frame with ipv6 overriding the IPv6 layer defaults and
+// additionalLayers added after it.
+func (conn *IPv6Conn) Send(t *testing.T, ipv6 IPv6, additionalLayers ...Layer) {
+ t.Helper()
-// CreateFrame builds a frame for the connection with ipv6 overriding the ipv6
-// layer defaults and additionalLayers added after it.
-func (conn *IPv6Conn) CreateFrame(ipv6 IPv6, additionalLayers ...Layer) Layers {
- return (*Connection)(conn).CreateFrame(&ipv6, additionalLayers...)
+ (*Connection)(conn).send(t, Layers{&ipv6}, additionalLayers...)
}
// Close to clean up any resources held.
-func (conn *IPv6Conn) Close() {
- (*Connection)(conn).Close()
+func (conn *IPv6Conn) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(conn).Close(t)
}
// ExpectFrame expects a frame that matches the provided Layers within the
// timeout specified. If it doesn't arrive in time, an error is returned.
-func (conn *IPv6Conn) ExpectFrame(frame Layers, timeout time.Duration) (Layers, error) {
- return (*Connection)(conn).ExpectFrame(frame, timeout)
-}
+func (conn *IPv6Conn) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) {
+ t.Helper()
-// Drain drains the sniffer's receive buffer by receiving packets until there's
-// nothing else to receive.
-func (conn *TCPIPv4) Drain() {
- conn.sniffer.Drain()
+ return (*Connection)(conn).ExpectFrame(t, frame, timeout)
}
// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection.
@@ -767,6 +899,8 @@ type UDPIPv4 Connection
// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults.
func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
+ t.Helper()
+
etherState, err := newEtherState(Ether{}, Ether{})
if err != nil {
t.Fatalf("can't make etherState: %s", err)
@@ -775,7 +909,7 @@ func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
if err != nil {
t.Fatalf("can't make ipv4State: %s", err)
}
- udpState, localAddr, err := newUDPState(unix.AF_INET, outgoingUDP, incomingUDP)
+ udpState, err := newUDPState(unix.AF_INET, outgoingUDP, incomingUDP)
if err != nil {
t.Fatalf("can't make udpState: %s", err)
}
@@ -792,78 +926,280 @@ func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
layerStates: []layerState{etherState, ipv4State, udpState},
injector: injector,
sniffer: sniffer,
- localAddr: localAddr,
- t: t,
}
}
+func (conn *UDPIPv4) udpState(t *testing.T) *udpState {
+ t.Helper()
+
+ state, ok := conn.layerStates[2].(*udpState)
+ if !ok {
+ t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2])
+ }
+ return state
+}
+
+func (conn *UDPIPv4) ipv4State(t *testing.T) *ipv4State {
+ t.Helper()
+
+ state, ok := conn.layerStates[1].(*ipv4State)
+ if !ok {
+ t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1])
+ }
+ return state
+}
+
// LocalAddr gets the local socket address of this connection.
-func (conn *UDPIPv4) LocalAddr() unix.Sockaddr {
- return conn.localAddr
+func (conn *UDPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 {
+ t.Helper()
+
+ sa := &unix.SockaddrInet4{Port: int(*conn.udpState(t).out.SrcPort)}
+ copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr)
+ return sa
}
-// CreateFrame builds a frame for the connection with layer overriding defaults
-// of the innermost layer and additionalLayers added after it.
-func (conn *UDPIPv4) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
- return (*Connection)(conn).CreateFrame(layer, additionalLayers...)
+// Send sends a packet with reasonable defaults, potentially overriding the UDP
+// layer and adding additionLayers.
+func (conn *UDPIPv4) Send(t *testing.T, udp UDP, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&udp}, additionalLayers...)
}
-// Send a packet with reasonable defaults. Potentially override the UDP layer in
-// the connection with the provided layer and add additionLayers.
-func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) {
- (*Connection)(conn).Send(&udp, additionalLayers...)
+// SendIP sends a packet with reasonable defaults, potentially overriding the
+// UDP and IPv4 headers and adding additionLayers.
+func (conn *UDPIPv4) SendIP(t *testing.T, ip IPv4, udp UDP, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...)
}
-// SendFrame sends a frame on the wire and updates the state of all layers.
-func (conn *UDPIPv4) SendFrame(frame Layers) {
- (*Connection)(conn).SendFrame(frame)
+// Expect expects a frame with the UDP layer matching the provided UDP within
+// the timeout specified. If it doesn't arrive in time, an error is returned.
+func (conn *UDPIPv4) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) {
+ t.Helper()
+
+ layer, err := (*Connection)(conn).Expect(t, &udp, timeout)
+ if err != nil {
+ return nil, err
+ }
+ gotUDP, ok := layer.(*UDP)
+ if !ok {
+ t.Fatalf("expected %s to be UDP", layer)
+ }
+ return gotUDP, nil
}
-// SendIP sends a packet with additionalLayers following the IP layer in the
-// connection.
-func (conn *UDPIPv4) SendIP(additionalLayers ...Layer) {
- var layersToSend Layers
- for _, s := range conn.layerStates[:len(conn.layerStates)-1] {
- layersToSend = append(layersToSend, s.outgoing())
+// ExpectData is a convenient method that expects a Layer and the Layer after
+// it. If it doens't arrive in time, it returns nil.
+func (conn *UDPIPv4) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = &udp
+ if payload.length() != 0 {
+ expected = append(expected, &payload)
}
- layersToSend = append(layersToSend, additionalLayers...)
- conn.SendFrame(layersToSend)
+ return (*Connection)(conn).ExpectFrame(t, expected, timeout)
+}
+
+// Close frees associated resources held by the UDPIPv4 connection.
+func (conn *UDPIPv4) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(conn).Close(t)
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *UDPIPv4) Drain(t *testing.T) {
+ t.Helper()
+
+ conn.sniffer.Drain(t)
+}
+
+// UDPIPv6 maintains the state for all the layers in a UDP/IPv6 connection.
+type UDPIPv6 Connection
+
+// NewUDPIPv6 creates a new UDPIPv6 connection with reasonable defaults.
+func NewUDPIPv6(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv6 {
+ t.Helper()
+
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make etherState: %s", err)
+ }
+ ipv6State, err := newIPv6State(IPv6{}, IPv6{})
+ if err != nil {
+ t.Fatalf("can't make IPv6State: %s", err)
+ }
+ udpState, err := newUDPState(unix.AF_INET6, outgoingUDP, incomingUDP)
+ if err != nil {
+ t.Fatalf("can't make udpState: %s", err)
+ }
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+ return UDPIPv6{
+ layerStates: []layerState{etherState, ipv6State, udpState},
+ injector: injector,
+ sniffer: sniffer,
+ }
+}
+
+func (conn *UDPIPv6) udpState(t *testing.T) *udpState {
+ t.Helper()
+
+ state, ok := conn.layerStates[2].(*udpState)
+ if !ok {
+ t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2])
+ }
+ return state
+}
+
+func (conn *UDPIPv6) ipv6State(t *testing.T) *ipv6State {
+ t.Helper()
+
+ state, ok := conn.layerStates[1].(*ipv6State)
+ if !ok {
+ t.Fatalf("got network-layer state type=%T, expected ipv6State", conn.layerStates[1])
+ }
+ return state
+}
+
+// LocalAddr gets the local socket address of this connection.
+func (conn *UDPIPv6) LocalAddr(t *testing.T) *unix.SockaddrInet6 {
+ t.Helper()
+
+ sa := &unix.SockaddrInet6{
+ Port: int(*conn.udpState(t).out.SrcPort),
+ // Local address is in perspective to the remote host, so it's scoped to the
+ // ID of the remote interface.
+ ZoneId: uint32(RemoteInterfaceID),
+ }
+ copy(sa.Addr[:], *conn.ipv6State(t).out.SrcAddr)
+ return sa
+}
+
+// Send sends a packet with reasonable defaults, potentially overriding the UDP
+// layer and adding additionLayers.
+func (conn *UDPIPv6) Send(t *testing.T, udp UDP, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&udp}, additionalLayers...)
+}
+
+// SendIPv6 sends a packet with reasonable defaults, potentially overriding the
+// UDP and IPv6 headers and adding additionLayers.
+func (conn *UDPIPv6) SendIPv6(t *testing.T, ip IPv6, udp UDP, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...)
}
// Expect expects a frame with the UDP layer matching the provided UDP within
// the timeout specified. If it doesn't arrive in time, an error is returned.
-func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) (*UDP, error) {
- conn.t.Helper()
- layer, err := (*Connection)(conn).Expect(&udp, timeout)
- if layer == nil {
+func (conn *UDPIPv6) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) {
+ t.Helper()
+
+ layer, err := (*Connection)(conn).Expect(t, &udp, timeout)
+ if err != nil {
return nil, err
}
gotUDP, ok := layer.(*UDP)
if !ok {
- conn.t.Fatalf("expected %s to be UDP", layer)
+ t.Fatalf("expected %s to be UDP", layer)
}
- return gotUDP, err
+ return gotUDP, nil
}
// ExpectData is a convenient method that expects a Layer and the Layer after
// it. If it doens't arrive in time, it returns nil.
-func (conn *UDPIPv4) ExpectData(udp UDP, payload Payload, timeout time.Duration) (Layers, error) {
- conn.t.Helper()
+func (conn *UDPIPv6) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
expected := make([]Layer, len(conn.layerStates))
expected[len(expected)-1] = &udp
if payload.length() != 0 {
expected = append(expected, &payload)
}
- return (*Connection)(conn).ExpectFrame(expected, timeout)
+ return (*Connection)(conn).ExpectFrame(t, expected, timeout)
}
-// Close frees associated resources held by the UDPIPv4 connection.
-func (conn *UDPIPv4) Close() {
- (*Connection)(conn).Close()
+// Close frees associated resources held by the UDPIPv6 connection.
+func (conn *UDPIPv6) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(conn).Close(t)
}
// Drain drains the sniffer's receive buffer by receiving packets until there's
// nothing else to receive.
-func (conn *UDPIPv4) Drain() {
- conn.sniffer.Drain()
+func (conn *UDPIPv6) Drain(t *testing.T) {
+ t.Helper()
+
+ conn.sniffer.Drain(t)
+}
+
+// TCPIPv6 maintains the state for all the layers in a TCP/IPv6 connection.
+type TCPIPv6 Connection
+
+// NewTCPIPv6 creates a new TCPIPv6 connection with reasonable defaults.
+func NewTCPIPv6(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv6 {
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make etherState: %s", err)
+ }
+ ipv6State, err := newIPv6State(IPv6{}, IPv6{})
+ if err != nil {
+ t.Fatalf("can't make ipv6State: %s", err)
+ }
+ tcpState, err := newTCPState(unix.AF_INET6, outgoingTCP, incomingTCP)
+ if err != nil {
+ t.Fatalf("can't make tcpState: %s", err)
+ }
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return TCPIPv6{
+ layerStates: []layerState{etherState, ipv6State, tcpState},
+ injector: injector,
+ sniffer: sniffer,
+ }
+}
+
+func (conn *TCPIPv6) SrcPort() uint16 {
+ state := conn.layerStates[2].(*tcpState)
+ return *state.out.SrcPort
+}
+
+// ExpectData is a convenient method that expects a Layer and the Layer after
+// it. If it doens't arrive in time, it returns nil.
+func (conn *TCPIPv6) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = tcp
+ if payload != nil {
+ expected = append(expected, payload)
+ }
+ return (*Connection)(conn).ExpectFrame(t, expected, timeout)
+}
+
+// Close frees associated resources held by the TCPIPv6 connection.
+func (conn *TCPIPv6) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(conn).Close(t)
}
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
index a78b7d7ee..73c532e75 100644
--- a/test/packetimpact/testbench/dut.go
+++ b/test/packetimpact/testbench/dut.go
@@ -16,6 +16,7 @@ package testbench
import (
"context"
+ "flag"
"net"
"strconv"
"syscall"
@@ -30,13 +31,19 @@ import (
// DUT communicates with the DUT to force it to make POSIX calls.
type DUT struct {
- t *testing.T
conn *grpc.ClientConn
posixServer POSIXClient
}
// NewDUT creates a new connection with the DUT over gRPC.
func NewDUT(t *testing.T) DUT {
+ t.Helper()
+
+ flag.Parse()
+ if err := genPseudoFlags(); err != nil {
+ t.Fatal("generating psuedo flags:", err)
+ }
+
posixServerAddress := POSIXServerIP + ":" + strconv.Itoa(POSIXServerPort)
conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: RPCKeepalive}))
if err != nil {
@@ -44,7 +51,6 @@ func NewDUT(t *testing.T) DUT {
}
posixServer := NewPOSIXClient(conn)
return DUT{
- t: t,
conn: conn,
posixServer: posixServer,
}
@@ -55,8 +61,9 @@ func (dut *DUT) TearDown() {
dut.conn.Close()
}
-func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr {
- dut.t.Helper()
+func (dut *DUT) sockaddrToProto(t *testing.T, sa unix.Sockaddr) *pb.Sockaddr {
+ t.Helper()
+
switch s := sa.(type) {
case *unix.SockaddrInet4:
return &pb.Sockaddr{
@@ -81,12 +88,13 @@ func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr {
},
}
}
- dut.t.Fatalf("can't parse Sockaddr: %+v", sa)
+ t.Fatalf("can't parse Sockaddr struct: %+v", sa)
return nil
}
-func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr {
- dut.t.Helper()
+func (dut *DUT) protoToSockaddr(t *testing.T, sa *pb.Sockaddr) unix.Sockaddr {
+ t.Helper()
+
switch s := sa.Sockaddr.(type) {
case *pb.Sockaddr_In:
ret := unix.SockaddrInet4{
@@ -100,31 +108,34 @@ func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr {
ZoneId: s.In6.GetScopeId(),
}
copy(ret.Addr[:], s.In6.GetAddr())
+ return &ret
}
- dut.t.Fatalf("can't parse Sockaddr: %+v", sa)
+ t.Fatalf("can't parse Sockaddr proto: %#v", sa)
return nil
}
// CreateBoundSocket makes a new socket on the DUT, with type typ and protocol
// proto, and bound to the IP address addr. Returns the new file descriptor and
// the port that was selected on the DUT.
-func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) {
- dut.t.Helper()
+func (dut *DUT) CreateBoundSocket(t *testing.T, typ, proto int32, addr net.IP) (int32, uint16) {
+ t.Helper()
+
var fd int32
if addr.To4() != nil {
- fd = dut.Socket(unix.AF_INET, typ, proto)
+ fd = dut.Socket(t, unix.AF_INET, typ, proto)
sa := unix.SockaddrInet4{}
copy(sa.Addr[:], addr.To4())
- dut.Bind(fd, &sa)
+ dut.Bind(t, fd, &sa)
} else if addr.To16() != nil {
- fd = dut.Socket(unix.AF_INET6, typ, proto)
+ fd = dut.Socket(t, unix.AF_INET6, typ, proto)
sa := unix.SockaddrInet6{}
copy(sa.Addr[:], addr.To16())
- dut.Bind(fd, &sa)
+ sa.ZoneId = uint32(RemoteInterfaceID)
+ dut.Bind(t, fd, &sa)
} else {
- dut.t.Fatalf("unknown ip addr type for remoteIP")
+ t.Fatalf("invalid IP address: %s", addr)
}
- sa := dut.GetSockName(fd)
+ sa := dut.GetSockName(t, fd)
var port int
switch s := sa.(type) {
case *unix.SockaddrInet4:
@@ -132,15 +143,17 @@ func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16)
case *unix.SockaddrInet6:
port = s.Port
default:
- dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa)
+ t.Fatalf("unknown sockaddr type from getsockname: %T", sa)
}
return fd, uint16(port)
}
// CreateListener makes a new TCP connection. If it fails, the test ends.
-func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) {
- fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(RemoteIPv4))
- dut.Listen(fd, backlog)
+func (dut *DUT) CreateListener(t *testing.T, typ, proto, backlog int32) (int32, uint16) {
+ t.Helper()
+
+ fd, remotePort := dut.CreateBoundSocket(t, typ, proto, net.ParseIP(RemoteIPv4))
+ dut.Listen(t, fd, backlog)
return fd, remotePort
}
@@ -150,53 +163,57 @@ func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) {
// Accept calls accept on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// AcceptWithErrno.
-func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) {
- dut.t.Helper()
+func (dut *DUT) Accept(t *testing.T, sockfd int32) (int32, unix.Sockaddr) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- fd, sa, err := dut.AcceptWithErrno(ctx, sockfd)
+ fd, sa, err := dut.AcceptWithErrno(ctx, t, sockfd)
if fd < 0 {
- dut.t.Fatalf("failed to accept: %s", err)
+ t.Fatalf("failed to accept: %s", err)
}
return fd, sa
}
// AcceptWithErrno calls accept on the DUT.
-func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
- dut.t.Helper()
+func (dut *DUT) AcceptWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) {
+ t.Helper()
+
req := pb.AcceptRequest{
Sockfd: sockfd,
}
resp, err := dut.posixServer.Accept(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Accept: %s", err)
+ t.Fatalf("failed to call Accept: %s", err)
}
- return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+ return resp.GetFd(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_())
}
// Bind calls bind on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is
// needed, use BindWithErrno.
-func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) {
- dut.t.Helper()
+func (dut *DUT) Bind(t *testing.T, fd int32, sa unix.Sockaddr) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.BindWithErrno(ctx, fd, sa)
+ ret, err := dut.BindWithErrno(ctx, t, fd, sa)
if ret != 0 {
- dut.t.Fatalf("failed to bind socket: %s", err)
+ t.Fatalf("failed to bind socket: %s", err)
}
}
// BindWithErrno calls bind on the DUT.
-func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) BindWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) {
+ t.Helper()
+
req := pb.BindRequest{
Sockfd: fd,
- Addr: dut.sockaddrToProto(sa),
+ Addr: dut.sockaddrToProto(t, sa),
}
resp, err := dut.posixServer.Bind(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Bind: %s", err)
+ t.Fatalf("failed to call Bind: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
@@ -204,25 +221,27 @@ func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (
// Close calls close on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// CloseWithErrno.
-func (dut *DUT) Close(fd int32) {
- dut.t.Helper()
+func (dut *DUT) Close(t *testing.T, fd int32) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.CloseWithErrno(ctx, fd)
+ ret, err := dut.CloseWithErrno(ctx, t, fd)
if ret != 0 {
- dut.t.Fatalf("failed to close: %s", err)
+ t.Fatalf("failed to close: %s", err)
}
}
// CloseWithErrno calls close on the DUT.
-func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) CloseWithErrno(ctx context.Context, t *testing.T, fd int32) (int32, error) {
+ t.Helper()
+
req := pb.CloseRequest{
Fd: fd,
}
resp, err := dut.posixServer.Close(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Close: %s", err)
+ t.Fatalf("failed to call Close: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
@@ -230,26 +249,61 @@ func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) {
// Connect calls connect on the DUT and causes a fatal test failure if it
// doesn't succeed. If more control over the timeout or error handling is
// needed, use ConnectWithErrno.
-func (dut *DUT) Connect(fd int32, sa unix.Sockaddr) {
- dut.t.Helper()
+func (dut *DUT) Connect(t *testing.T, fd int32, sa unix.Sockaddr) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.ConnectWithErrno(ctx, fd, sa)
- if ret != 0 {
- dut.t.Fatalf("failed to connect socket: %s", err)
+ ret, err := dut.ConnectWithErrno(ctx, t, fd, sa)
+ // Ignore 'operation in progress' error that can be returned when the socket
+ // is non-blocking.
+ if err != syscall.Errno(unix.EINPROGRESS) && ret != 0 {
+ t.Fatalf("failed to connect socket: %s", err)
}
}
// ConnectWithErrno calls bind on the DUT.
-func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) ConnectWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) {
+ t.Helper()
+
req := pb.ConnectRequest{
Sockfd: fd,
- Addr: dut.sockaddrToProto(sa),
+ Addr: dut.sockaddrToProto(t, sa),
}
resp, err := dut.posixServer.Connect(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Connect: %s", err)
+ t.Fatalf("failed to call Connect: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// Fcntl calls fcntl on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use FcntlWithErrno.
+func (dut *DUT) Fcntl(t *testing.T, fd, cmd, arg int32) int32 {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.FcntlWithErrno(ctx, t, fd, cmd, arg)
+ if ret == -1 {
+ t.Fatalf("failed to Fcntl: ret=%d, errno=%s", ret, err)
+ }
+ return ret
+}
+
+// FcntlWithErrno calls fcntl on the DUT.
+func (dut *DUT) FcntlWithErrno(ctx context.Context, t *testing.T, fd, cmd, arg int32) (int32, error) {
+ t.Helper()
+
+ req := pb.FcntlRequest{
+ Fd: fd,
+ Cmd: cmd,
+ Arg: arg,
+ }
+ resp, err := dut.posixServer.Fcntl(ctx, &req)
+ if err != nil {
+ t.Fatalf("failed to call Fcntl: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
@@ -257,32 +311,35 @@ func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr
// GetSockName calls getsockname on the DUT and causes a fatal test failure if
// it doesn't succeed. If more control over the timeout or error handling is
// needed, use GetSockNameWithErrno.
-func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr {
- dut.t.Helper()
+func (dut *DUT) GetSockName(t *testing.T, sockfd int32) unix.Sockaddr {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, sa, err := dut.GetSockNameWithErrno(ctx, sockfd)
+ ret, sa, err := dut.GetSockNameWithErrno(ctx, t, sockfd)
if ret != 0 {
- dut.t.Fatalf("failed to getsockname: %s", err)
+ t.Fatalf("failed to getsockname: %s", err)
}
return sa
}
// GetSockNameWithErrno calls getsockname on the DUT.
-func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
- dut.t.Helper()
+func (dut *DUT) GetSockNameWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) {
+ t.Helper()
+
req := pb.GetSockNameRequest{
Sockfd: sockfd,
}
resp, err := dut.posixServer.GetSockName(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Bind: %s", err)
+ t.Fatalf("failed to call Bind: %s", err)
}
- return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_())
}
-func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) {
- dut.t.Helper()
+func (dut *DUT) getSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) {
+ t.Helper()
+
req := pb.GetSockOptRequest{
Sockfd: sockfd,
Level: level,
@@ -292,11 +349,11 @@ func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen i
}
resp, err := dut.posixServer.GetSockOpt(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call GetSockOpt: %s", err)
+ t.Fatalf("failed to call GetSockOpt: %s", err)
}
optval := resp.GetOptval()
if optval == nil {
- dut.t.Fatalf("GetSockOpt response does not contain a value")
+ t.Fatalf("GetSockOpt response does not contain a value")
}
return resp.GetRet(), optval, syscall.Errno(resp.GetErrno_())
}
@@ -306,13 +363,14 @@ func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen i
// needed, use GetSockOptWithErrno. Because endianess and the width of values
// might differ between the testbench and DUT architectures, prefer to use a
// more specific GetSockOptXxx function.
-func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte {
- dut.t.Helper()
+func (dut *DUT) GetSockOpt(t *testing.T, sockfd, level, optname, optlen int32) []byte {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, optval, err := dut.GetSockOptWithErrno(ctx, sockfd, level, optname, optlen)
+ ret, optval, err := dut.GetSockOptWithErrno(ctx, t, sockfd, level, optname, optlen)
if ret != 0 {
- dut.t.Fatalf("failed to GetSockOpt: %s", err)
+ t.Fatalf("failed to GetSockOpt: %s", err)
}
return optval
}
@@ -320,12 +378,13 @@ func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte {
// GetSockOptWithErrno calls getsockopt on the DUT. Because endianess and the
// width of values might differ between the testbench and DUT architectures,
// prefer to use a more specific GetSockOptXxxWithErrno function.
-func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname, optlen int32) (int32, []byte, error) {
- dut.t.Helper()
- ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, optlen, pb.GetSockOptRequest_BYTES)
+func (dut *DUT) GetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32) (int32, []byte, error) {
+ t.Helper()
+
+ ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, optlen, pb.GetSockOptRequest_BYTES)
bytesval, ok := optval.Val.(*pb.SockOptVal_Bytesval)
if !ok {
- dut.t.Fatalf("GetSockOpt got value type: %T, want bytes", optval)
+ t.Fatalf("GetSockOpt got value type: %T, want bytes", optval.Val)
}
return ret, bytesval.Bytesval, errno
}
@@ -333,24 +392,26 @@ func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname,
// GetSockOptInt calls getsockopt on the DUT and causes a fatal test failure
// if it doesn't succeed. If more control over the int optval or error handling
// is needed, use GetSockOptIntWithErrno.
-func (dut *DUT) GetSockOptInt(sockfd, level, optname int32) int32 {
- dut.t.Helper()
+func (dut *DUT) GetSockOptInt(t *testing.T, sockfd, level, optname int32) int32 {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, intval, err := dut.GetSockOptIntWithErrno(ctx, sockfd, level, optname)
+ ret, intval, err := dut.GetSockOptIntWithErrno(ctx, t, sockfd, level, optname)
if ret != 0 {
- dut.t.Fatalf("failed to GetSockOptInt: %s", err)
+ t.Fatalf("failed to GetSockOptInt: %s", err)
}
return intval
}
// GetSockOptIntWithErrno calls getsockopt with an integer optval.
-func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, int32, error) {
- dut.t.Helper()
- ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, 0, pb.GetSockOptRequest_INT)
+func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, int32, error) {
+ t.Helper()
+
+ ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_INT)
intval, ok := optval.Val.(*pb.SockOptVal_Intval)
if !ok {
- dut.t.Fatalf("GetSockOpt got value type: %T, want int", optval)
+ t.Fatalf("GetSockOpt got value type: %T, want int", optval.Val)
}
return ret, intval.Intval, errno
}
@@ -358,24 +419,26 @@ func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optna
// GetSockOptTimeval calls getsockopt on the DUT and causes a fatal test failure
// if it doesn't succeed. If more control over the timeout or error handling is
// needed, use GetSockOptTimevalWithErrno.
-func (dut *DUT) GetSockOptTimeval(sockfd, level, optname int32) unix.Timeval {
- dut.t.Helper()
+func (dut *DUT) GetSockOptTimeval(t *testing.T, sockfd, level, optname int32) unix.Timeval {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, sockfd, level, optname)
+ ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname)
if ret != 0 {
- dut.t.Fatalf("failed to GetSockOptTimeval: %s", err)
+ t.Fatalf("failed to GetSockOptTimeval: %s", err)
}
return timeval
}
// GetSockOptTimevalWithErrno calls getsockopt and returns a timeval.
-func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, unix.Timeval, error) {
- dut.t.Helper()
- ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, 0, pb.GetSockOptRequest_TIME)
+func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, unix.Timeval, error) {
+ t.Helper()
+
+ ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_TIME)
tv, ok := optval.Val.(*pb.SockOptVal_Timeval)
if !ok {
- dut.t.Fatalf("GetSockOpt got value type: %T, want timeval", optval)
+ t.Fatalf("GetSockOpt got value type: %T, want timeval", optval.Val)
}
timeval := unix.Timeval{
Sec: tv.Timeval.Seconds,
@@ -387,26 +450,28 @@ func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, o
// Listen calls listen on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// ListenWithErrno.
-func (dut *DUT) Listen(sockfd, backlog int32) {
- dut.t.Helper()
+func (dut *DUT) Listen(t *testing.T, sockfd, backlog int32) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.ListenWithErrno(ctx, sockfd, backlog)
+ ret, err := dut.ListenWithErrno(ctx, t, sockfd, backlog)
if ret != 0 {
- dut.t.Fatalf("failed to listen: %s", err)
+ t.Fatalf("failed to listen: %s", err)
}
}
// ListenWithErrno calls listen on the DUT.
-func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) ListenWithErrno(ctx context.Context, t *testing.T, sockfd, backlog int32) (int32, error) {
+ t.Helper()
+
req := pb.ListenRequest{
Sockfd: sockfd,
Backlog: backlog,
}
resp, err := dut.posixServer.Listen(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Listen: %s", err)
+ t.Fatalf("failed to call Listen: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
@@ -414,20 +479,22 @@ func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int
// Send calls send on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// SendWithErrno.
-func (dut *DUT) Send(sockfd int32, buf []byte, flags int32) int32 {
- dut.t.Helper()
+func (dut *DUT) Send(t *testing.T, sockfd int32, buf []byte, flags int32) int32 {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.SendWithErrno(ctx, sockfd, buf, flags)
+ ret, err := dut.SendWithErrno(ctx, t, sockfd, buf, flags)
if ret == -1 {
- dut.t.Fatalf("failed to send: %s", err)
+ t.Fatalf("failed to send: %s", err)
}
return ret
}
// SendWithErrno calls send on the DUT.
-func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) SendWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32) (int32, error) {
+ t.Helper()
+
req := pb.SendRequest{
Sockfd: sockfd,
Buf: buf,
@@ -435,7 +502,7 @@ func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, fla
}
resp, err := dut.posixServer.Send(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Send: %s", err)
+ t.Fatalf("failed to call Send: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
@@ -443,35 +510,52 @@ func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, fla
// SendTo calls sendto on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// SendToWithErrno.
-func (dut *DUT) SendTo(sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 {
- dut.t.Helper()
+func (dut *DUT) SendTo(t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.SendToWithErrno(ctx, sockfd, buf, flags, destAddr)
+ ret, err := dut.SendToWithErrno(ctx, t, sockfd, buf, flags, destAddr)
if ret == -1 {
- dut.t.Fatalf("failed to sendto: %s", err)
+ t.Fatalf("failed to sendto: %s", err)
}
return ret
}
// SendToWithErrno calls sendto on the DUT.
-func (dut *DUT) SendToWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) SendToWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) {
+ t.Helper()
+
req := pb.SendToRequest{
Sockfd: sockfd,
Buf: buf,
Flags: flags,
- DestAddr: dut.sockaddrToProto(destAddr),
+ DestAddr: dut.sockaddrToProto(t, destAddr),
}
resp, err := dut.posixServer.SendTo(ctx, &req)
if err != nil {
- dut.t.Fatalf("faled to call SendTo: %s", err)
+ t.Fatalf("faled to call SendTo: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) {
- dut.t.Helper()
+// SetNonBlocking will set O_NONBLOCK flag for fd if nonblocking
+// is true, otherwise it will clear the flag.
+func (dut *DUT) SetNonBlocking(t *testing.T, fd int32, nonblocking bool) {
+ t.Helper()
+
+ flags := dut.Fcntl(t, fd, unix.F_GETFL, 0)
+ if nonblocking {
+ flags |= unix.O_NONBLOCK
+ } else {
+ flags &= ^unix.O_NONBLOCK
+ }
+ dut.Fcntl(t, fd, unix.F_SETFL, flags)
+}
+
+func (dut *DUT) setSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) {
+ t.Helper()
+
req := pb.SetSockOptRequest{
Sockfd: sockfd,
Level: level,
@@ -480,7 +564,7 @@ func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, op
}
resp, err := dut.posixServer.SetSockOpt(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call SetSockOpt: %s", err)
+ t.Fatalf("failed to call SetSockOpt: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
@@ -490,81 +574,89 @@ func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, op
// needed, use SetSockOptWithErrno. Because endianess and the width of values
// might differ between the testbench and DUT architectures, prefer to use a
// more specific SetSockOptXxx function.
-func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) {
- dut.t.Helper()
+func (dut *DUT) SetSockOpt(t *testing.T, sockfd, level, optname int32, optval []byte) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval)
+ ret, err := dut.SetSockOptWithErrno(ctx, t, sockfd, level, optname, optval)
if ret != 0 {
- dut.t.Fatalf("failed to SetSockOpt: %s", err)
+ t.Fatalf("failed to SetSockOpt: %s", err)
}
}
// SetSockOptWithErrno calls setsockopt on the DUT. Because endianess and the
// width of values might differ between the testbench and DUT architectures,
// prefer to use a more specific SetSockOptXxxWithErrno function.
-func (dut *DUT) SetSockOptWithErrno(ctx context.Context, sockfd, level, optname int32, optval []byte) (int32, error) {
- dut.t.Helper()
- return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Bytesval{optval}})
+func (dut *DUT) SetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval []byte) (int32, error) {
+ t.Helper()
+
+ return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Bytesval{optval}})
}
// SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure
// if it doesn't succeed. If more control over the int optval or error handling
// is needed, use SetSockOptIntWithErrno.
-func (dut *DUT) SetSockOptInt(sockfd, level, optname, optval int32) {
- dut.t.Helper()
+func (dut *DUT) SetSockOptInt(t *testing.T, sockfd, level, optname, optval int32) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.SetSockOptIntWithErrno(ctx, sockfd, level, optname, optval)
+ ret, err := dut.SetSockOptIntWithErrno(ctx, t, sockfd, level, optname, optval)
if ret != 0 {
- dut.t.Fatalf("failed to SetSockOptInt: %s", err)
+ t.Fatalf("failed to SetSockOptInt: %s", err)
}
}
// SetSockOptIntWithErrno calls setsockopt with an integer optval.
-func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname, optval int32) (int32, error) {
- dut.t.Helper()
- return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Intval{optval}})
+func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optval int32) (int32, error) {
+ t.Helper()
+
+ return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Intval{optval}})
}
// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure
// if it doesn't succeed. If more control over the timeout or error handling is
// needed, use SetSockOptTimevalWithErrno.
-func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) {
- dut.t.Helper()
+func (dut *DUT) SetSockOptTimeval(t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.SetSockOptTimevalWithErrno(ctx, sockfd, level, optname, tv)
+ ret, err := dut.SetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname, tv)
if ret != 0 {
- dut.t.Fatalf("failed to SetSockOptTimeval: %s", err)
+ t.Fatalf("failed to SetSockOptTimeval: %s", err)
}
}
// SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to
// bytes.
-func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) {
+ t.Helper()
+
timeval := pb.Timeval{
Seconds: int64(tv.Sec),
Microseconds: int64(tv.Usec),
}
- return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Timeval{&timeval}})
+ return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Timeval{&timeval}})
}
// Socket calls socket on the DUT and returns the file descriptor. If socket
// fails on the DUT, the test ends.
-func (dut *DUT) Socket(domain, typ, proto int32) int32 {
- dut.t.Helper()
- fd, err := dut.SocketWithErrno(domain, typ, proto)
+func (dut *DUT) Socket(t *testing.T, domain, typ, proto int32) int32 {
+ t.Helper()
+
+ fd, err := dut.SocketWithErrno(t, domain, typ, proto)
if fd < 0 {
- dut.t.Fatalf("failed to create socket: %s", err)
+ t.Fatalf("failed to create socket: %s", err)
}
return fd
}
// SocketWithErrno calls socket on the DUT and returns the fd and errno.
-func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) SocketWithErrno(t *testing.T, domain, typ, proto int32) (int32, error) {
+ t.Helper()
+
req := pb.SocketRequest{
Domain: domain,
Type: typ,
@@ -573,7 +665,7 @@ func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
ctx := context.Background()
resp, err := dut.posixServer.Socket(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Socket: %s", err)
+ t.Fatalf("failed to call Socket: %s", err)
}
return resp.GetFd(), syscall.Errno(resp.GetErrno_())
}
@@ -581,20 +673,22 @@ func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
// Recv calls recv on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// RecvWithErrno.
-func (dut *DUT) Recv(sockfd, len, flags int32) []byte {
- dut.t.Helper()
+func (dut *DUT) Recv(t *testing.T, sockfd, len, flags int32) []byte {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags)
+ ret, buf, err := dut.RecvWithErrno(ctx, t, sockfd, len, flags)
if ret == -1 {
- dut.t.Fatalf("failed to recv: %s", err)
+ t.Fatalf("failed to recv: %s", err)
}
return buf
}
// RecvWithErrno calls recv on the DUT.
-func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (int32, []byte, error) {
- dut.t.Helper()
+func (dut *DUT) RecvWithErrno(ctx context.Context, t *testing.T, sockfd, len, flags int32) (int32, []byte, error) {
+ t.Helper()
+
req := pb.RecvRequest{
Sockfd: sockfd,
Len: len,
@@ -602,7 +696,7 @@ func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (in
}
resp, err := dut.posixServer.Recv(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Recv: %s", err)
+ t.Fatalf("failed to call Recv: %s", err)
}
return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_())
}
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index 1b0e5b8fc..24aa46cce 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -15,6 +15,7 @@
package testbench
import (
+ "encoding/binary"
"encoding/hex"
"fmt"
"reflect"
@@ -470,17 +471,11 @@ func (l *IPv6) ToBytes() ([]byte, error) {
if l.NextHeader != nil {
fields.NextHeader = *l.NextHeader
} else {
- switch n := l.next().(type) {
- case *TCP:
- fields.NextHeader = uint8(header.TCPProtocolNumber)
- case *UDP:
- fields.NextHeader = uint8(header.UDPProtocolNumber)
- case *ICMPv6:
- fields.NextHeader = uint8(header.ICMPv6ProtocolNumber)
- default:
- // TODO(b/150301488): Support more protocols as needed.
- return nil, fmt.Errorf("ToBytes can't deduce the IPv6 header's next protocol: %#v", n)
+ nh, err := nextHeaderByLayer(l.next())
+ if err != nil {
+ return nil, err
}
+ fields.NextHeader = nh
}
if l.HopLimit != nil {
fields.HopLimit = *l.HopLimit
@@ -495,6 +490,27 @@ func (l *IPv6) ToBytes() ([]byte, error) {
return h, nil
}
+// nextIPv6PayloadParser finds the corresponding parser for nextHeader.
+func nextIPv6PayloadParser(nextHeader uint8) layerParser {
+ switch tcpip.TransportProtocolNumber(nextHeader) {
+ case header.TCPProtocolNumber:
+ return parseTCP
+ case header.UDPProtocolNumber:
+ return parseUDP
+ case header.ICMPv6ProtocolNumber:
+ return parseICMPv6
+ }
+ switch header.IPv6ExtensionHeaderIdentifier(nextHeader) {
+ case header.IPv6HopByHopOptionsExtHdrIdentifier:
+ return parseIPv6HopByHopOptionsExtHdr
+ case header.IPv6DestinationOptionsExtHdrIdentifier:
+ return parseIPv6DestinationOptionsExtHdr
+ case header.IPv6FragmentExtHdrIdentifier:
+ return parseIPv6FragmentExtHdr
+ }
+ return parsePayload
+}
+
// parseIPv6 parses the bytes assuming that they start with an ipv6 header and
// continues parsing further encapsulations.
func parseIPv6(b []byte) (Layer, layerParser) {
@@ -509,18 +525,7 @@ func parseIPv6(b []byte) (Layer, layerParser) {
SrcAddr: Address(h.SourceAddress()),
DstAddr: Address(h.DestinationAddress()),
}
- var nextParser layerParser
- switch h.TransportProtocol() {
- case header.TCPProtocolNumber:
- nextParser = parseTCP
- case header.UDPProtocolNumber:
- nextParser = parseUDP
- case header.ICMPv6ProtocolNumber:
- nextParser = parseICMPv6
- default:
- // Assume that the rest is a payload.
- nextParser = parsePayload
- }
+ nextParser := nextIPv6PayloadParser(h.NextHeader())
return &ipv6, nextParser
}
@@ -538,13 +543,241 @@ func (l *IPv6) merge(other Layer) error {
return mergeLayer(l, other)
}
+// IPv6HopByHopOptionsExtHdr can construct and match an IPv6HopByHopOptions
+// Extension Header.
+type IPv6HopByHopOptionsExtHdr struct {
+ LayerBase
+ NextHeader *header.IPv6ExtensionHeaderIdentifier
+ Options []byte
+}
+
+// IPv6DestinationOptionsExtHdr can construct and match an IPv6DestinationOptions
+// Extension Header.
+type IPv6DestinationOptionsExtHdr struct {
+ LayerBase
+ NextHeader *header.IPv6ExtensionHeaderIdentifier
+ Options []byte
+}
+
+// IPv6FragmentExtHdr can construct and match an IPv6 Fragment Extension Header.
+type IPv6FragmentExtHdr struct {
+ LayerBase
+ NextHeader *header.IPv6ExtensionHeaderIdentifier
+ FragmentOffset *uint16
+ MoreFragments *bool
+ Identification *uint32
+}
+
+// nextHeaderByLayer finds the correct next header protocol value for layer l.
+func nextHeaderByLayer(l Layer) (uint8, error) {
+ if l == nil {
+ return uint8(header.IPv6NoNextHeaderIdentifier), nil
+ }
+ switch l.(type) {
+ case *TCP:
+ return uint8(header.TCPProtocolNumber), nil
+ case *UDP:
+ return uint8(header.UDPProtocolNumber), nil
+ case *ICMPv6:
+ return uint8(header.ICMPv6ProtocolNumber), nil
+ case *Payload:
+ return uint8(header.IPv6NoNextHeaderIdentifier), nil
+ case *IPv6HopByHopOptionsExtHdr:
+ return uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), nil
+ case *IPv6DestinationOptionsExtHdr:
+ return uint8(header.IPv6DestinationOptionsExtHdrIdentifier), nil
+ case *IPv6FragmentExtHdr:
+ return uint8(header.IPv6FragmentExtHdrIdentifier), nil
+ default:
+ // TODO(b/161005083): Support more protocols as needed.
+ return 0, fmt.Errorf("failed to deduce the IPv6 header's next protocol: %T", l)
+ }
+}
+
+// ipv6OptionsExtHdrToBytes serializes an options extension header into bytes.
+func ipv6OptionsExtHdrToBytes(nextHeader *header.IPv6ExtensionHeaderIdentifier, nextLayer Layer, options []byte) ([]byte, error) {
+ length := len(options) + 2
+ if length%8 != 0 {
+ return nil, fmt.Errorf("IPv6 extension headers must be a multiple of 8 octets long, but the length given: %d, options: %s", length, hex.Dump(options))
+ }
+ bytes := make([]byte, length)
+ if nextHeader != nil {
+ bytes[0] = byte(*nextHeader)
+ } else {
+ nh, err := nextHeaderByLayer(nextLayer)
+ if err != nil {
+ return nil, err
+ }
+ bytes[0] = nh
+ }
+ // ExtHdrLen field is the length of the extension header
+ // in 8-octet unit, ignoring the first 8 octets.
+ // https://tools.ietf.org/html/rfc2460#section-4.3
+ // https://tools.ietf.org/html/rfc2460#section-4.6
+ bytes[1] = uint8((length - 8) / 8)
+ copy(bytes[2:], options)
+ return bytes, nil
+}
+
+// IPv6ExtHdrIdent is a helper routine that allocates a new
+// header.IPv6ExtensionHeaderIdentifier value to store v and returns a pointer
+// to it.
+func IPv6ExtHdrIdent(id header.IPv6ExtensionHeaderIdentifier) *header.IPv6ExtensionHeaderIdentifier {
+ return &id
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *IPv6HopByHopOptionsExtHdr) ToBytes() ([]byte, error) {
+ return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *IPv6DestinationOptionsExtHdr) ToBytes() ([]byte, error) {
+ return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *IPv6FragmentExtHdr) ToBytes() ([]byte, error) {
+ var offset, mflag uint16
+ var ident uint32
+ bytes := make([]byte, header.IPv6FragmentExtHdrLength)
+ if l.NextHeader != nil {
+ bytes[0] = byte(*l.NextHeader)
+ } else {
+ nh, err := nextHeaderByLayer(l.next())
+ if err != nil {
+ return nil, err
+ }
+ bytes[0] = nh
+ }
+ bytes[1] = 0 // reserved
+ if l.MoreFragments != nil && *l.MoreFragments {
+ mflag = 1
+ }
+ if l.FragmentOffset != nil {
+ offset = *l.FragmentOffset
+ }
+ if l.Identification != nil {
+ ident = *l.Identification
+ }
+ offsetAndMflag := offset<<3 | mflag
+ binary.BigEndian.PutUint16(bytes[2:], offsetAndMflag)
+ binary.BigEndian.PutUint32(bytes[4:], ident)
+
+ return bytes, nil
+}
+
+// parseIPv6ExtHdr parses an IPv6 extension header and returns the NextHeader
+// field, the rest of the payload and a parser function for the corresponding
+// next extension header.
+func parseIPv6ExtHdr(b []byte) (header.IPv6ExtensionHeaderIdentifier, []byte, layerParser) {
+ nextHeader := b[0]
+ // For HopByHop and Destination options extension headers,
+ // This field is the length of the extension header in
+ // 8-octet units, not including the first 8 octets.
+ // https://tools.ietf.org/html/rfc2460#section-4.3
+ // https://tools.ietf.org/html/rfc2460#section-4.6
+ length := b[1]*8 + 8
+ data := b[2:length]
+ nextParser := nextIPv6PayloadParser(nextHeader)
+ return header.IPv6ExtensionHeaderIdentifier(nextHeader), data, nextParser
+}
+
+// parseIPv6HopByHopOptionsExtHdr parses the bytes assuming that they start
+// with an IPv6 HopByHop Options Extension Header.
+func parseIPv6HopByHopOptionsExtHdr(b []byte) (Layer, layerParser) {
+ nextHeader, options, nextParser := parseIPv6ExtHdr(b)
+ return &IPv6HopByHopOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser
+}
+
+// parseIPv6DestinationOptionsExtHdr parses the bytes assuming that they start
+// with an IPv6 Destination Options Extension Header.
+func parseIPv6DestinationOptionsExtHdr(b []byte) (Layer, layerParser) {
+ nextHeader, options, nextParser := parseIPv6ExtHdr(b)
+ return &IPv6DestinationOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser
+}
+
+// Bool is a helper routine that allocates a new
+// bool value to store v and returns a pointer to it.
+func Bool(v bool) *bool {
+ return &v
+}
+
+// parseIPv6FragmentExtHdr parses the bytes assuming that they start
+// with an IPv6 Fragment Extension Header.
+func parseIPv6FragmentExtHdr(b []byte) (Layer, layerParser) {
+ nextHeader := b[0]
+ var extHdr header.IPv6FragmentExtHdr
+ copy(extHdr[:], b[2:])
+ return &IPv6FragmentExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6ExtensionHeaderIdentifier(nextHeader)),
+ FragmentOffset: Uint16(extHdr.FragmentOffset()),
+ MoreFragments: Bool(extHdr.More()),
+ Identification: Uint32(extHdr.ID()),
+ }, nextIPv6PayloadParser(nextHeader)
+}
+
+func (l *IPv6HopByHopOptionsExtHdr) length() int {
+ return len(l.Options) + 2
+}
+
+func (l *IPv6HopByHopOptionsExtHdr) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *IPv6HopByHopOptionsExtHdr) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+func (l *IPv6HopByHopOptionsExtHdr) String() string {
+ return stringLayer(l)
+}
+
+func (l *IPv6DestinationOptionsExtHdr) length() int {
+ return len(l.Options) + 2
+}
+
+func (l *IPv6DestinationOptionsExtHdr) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *IPv6DestinationOptionsExtHdr) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+func (l *IPv6DestinationOptionsExtHdr) String() string {
+ return stringLayer(l)
+}
+
+func (*IPv6FragmentExtHdr) length() int {
+ return header.IPv6FragmentExtHdrLength
+}
+
+func (l *IPv6FragmentExtHdr) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *IPv6FragmentExtHdr) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+func (l *IPv6FragmentExtHdr) String() string {
+ return stringLayer(l)
+}
+
// ICMPv6 can construct and match an ICMPv6 encapsulation.
type ICMPv6 struct {
LayerBase
- Type *header.ICMPv6Type
- Code *byte
- Checksum *uint16
- NDPPayload []byte
+ Type *header.ICMPv6Type
+ Code *byte
+ Checksum *uint16
+ Payload []byte
}
func (l *ICMPv6) String() string {
@@ -555,7 +788,7 @@ func (l *ICMPv6) String() string {
// ToBytes implements Layer.ToBytes.
func (l *ICMPv6) ToBytes() ([]byte, error) {
- b := make([]byte, header.ICMPv6HeaderSize+len(l.NDPPayload))
+ b := make([]byte, header.ICMPv6HeaderSize+len(l.Payload))
h := header.ICMPv6(b)
if l.Type != nil {
h.SetType(*l.Type)
@@ -563,12 +796,23 @@ func (l *ICMPv6) ToBytes() ([]byte, error) {
if l.Code != nil {
h.SetCode(*l.Code)
}
- copy(h.NDPPayload(), l.NDPPayload)
+ copy(h.NDPPayload(), l.Payload)
if l.Checksum != nil {
h.SetChecksum(*l.Checksum)
} else {
- ipv6 := l.Prev().(*IPv6)
- h.SetChecksum(header.ICMPv6Checksum(h, *ipv6.SrcAddr, *ipv6.DstAddr, buffer.VectorisedView{}))
+ // It is possible that the ICMPv6 header does not follow the IPv6 header
+ // immediately, there could be one or more extension headers in between.
+ // We need to search forward to find the IPv6 header.
+ for prev := l.Prev(); prev != nil; prev = prev.Prev() {
+ if ipv6, ok := prev.(*IPv6); ok {
+ payload, err := payload(l)
+ if err != nil {
+ return nil, err
+ }
+ h.SetChecksum(header.ICMPv6Checksum(h, *ipv6.SrcAddr, *ipv6.DstAddr, payload))
+ break
+ }
+ }
}
return h, nil
}
@@ -589,10 +833,10 @@ func Byte(v byte) *byte {
func parseICMPv6(b []byte) (Layer, layerParser) {
h := header.ICMPv6(b)
icmpv6 := ICMPv6{
- Type: ICMPv6Type(h.Type()),
- Code: Byte(h.Code()),
- Checksum: Uint16(h.Checksum()),
- NDPPayload: h.NDPPayload(),
+ Type: ICMPv6Type(h.Type()),
+ Code: Byte(h.Code()),
+ Checksum: Uint16(h.Checksum()),
+ Payload: h.NDPPayload(),
}
return &icmpv6, nil
}
@@ -602,7 +846,7 @@ func (l *ICMPv6) match(other Layer) bool {
}
func (l *ICMPv6) length() int {
- return header.ICMPv6HeaderSize + len(l.NDPPayload)
+ return header.ICMPv6HeaderSize + len(l.Payload)
}
// merge overrides the values in l with the values from other but only in fields
@@ -768,12 +1012,14 @@ func payload(l Layer) (buffer.VectorisedView, error) {
func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) {
totalLength := uint16(totalLength(l))
var xsum uint16
- switch s := l.Prev().(type) {
+ switch p := l.Prev().(type) {
case *IPv4:
- xsum = header.PseudoHeaderChecksum(protoNumber, *s.SrcAddr, *s.DstAddr, totalLength)
+ xsum = header.PseudoHeaderChecksum(protoNumber, *p.SrcAddr, *p.DstAddr, totalLength)
+ case *IPv6:
+ xsum = header.PseudoHeaderChecksum(protoNumber, *p.SrcAddr, *p.DstAddr, totalLength)
default:
- // TODO(b/150301488): Support more protocols, like IPv6.
- return 0, fmt.Errorf("can't get src and dst addr from previous layer: %#v", s)
+ // TODO(b/161246171): Support more protocols.
+ return 0, fmt.Errorf("checksum for protocol %d is not supported when previous layer is %T", protoNumber, p)
}
payloadBytes, err := payload(l)
if err != nil {
@@ -939,6 +1185,11 @@ func (l *Payload) ToBytes() ([]byte, error) {
return l.Bytes, nil
}
+// Length returns payload byte length.
+func (l *Payload) Length() int {
+ return l.length()
+}
+
func (l *Payload) match(other Layer) bool {
return equalLayer(l, other)
}
diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go
index c7f00e70d..a2a763034 100644
--- a/test/packetimpact/testbench/layers_test.go
+++ b/test/packetimpact/testbench/layers_test.go
@@ -505,3 +505,224 @@ func TestTCPOptions(t *testing.T) {
})
}
}
+
+func TestIPv6ExtHdrOptions(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ wantBytes []byte
+ wantLayers Layers
+ }{
+ {
+ description: "IPv6/HopByHop",
+ wantBytes: []byte{
+ // IPv6 Header
+ 0x60, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x40, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
+ // HopByHop Options
+ 0x3b, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ },
+ wantLayers: []Layer{
+ &IPv6{
+ SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))),
+ DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))),
+ },
+ &IPv6HopByHopOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &Payload{
+ Bytes: nil,
+ },
+ },
+ },
+ {
+ description: "IPv6/HopByHop/Payload",
+ wantBytes: []byte{
+ // IPv6 Header
+ 0x60, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x40, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
+ // HopByHop Options
+ 0x3b, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ // Sample Data
+ 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61,
+ },
+ wantLayers: []Layer{
+ &IPv6{
+ SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))),
+ DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))),
+ },
+ &IPv6HopByHopOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &Payload{
+ Bytes: []byte("Sample Data"),
+ },
+ },
+ },
+ {
+ description: "IPv6/HopByHop/Destination/ICMPv6",
+ wantBytes: []byte{
+ // IPv6 Header
+ 0x60, 0x00, 0x00, 0x00, 0x00, 0x18, 0x00, 0x40, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
+ // HopByHop Options
+ 0x3c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ // Destination Options
+ 0x3a, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ // ICMPv6 Param Problem
+ 0x04, 0x00, 0x5f, 0x98, 0x00, 0x00, 0x00, 0x06,
+ },
+ wantLayers: []Layer{
+ &IPv6{
+ SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))),
+ DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))),
+ },
+ &IPv6HopByHopOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6DestinationOptionsExtHdrIdentifier),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &IPv6DestinationOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber)),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &ICMPv6{
+ Type: ICMPv6Type(header.ICMPv6ParamProblem),
+ Code: Byte(0),
+ Checksum: Uint16(0x5f98),
+ Payload: []byte{0x00, 0x00, 0x00, 0x06},
+ },
+ },
+ },
+ {
+ description: "IPv6/HopByHop/Fragment",
+ wantBytes: []byte{
+ // IPv6 Header
+ 0x60, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x40, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
+ // HopByHop Options
+ 0x2c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ // Fragment ExtHdr
+ 0x3b, 0x00, 0x03, 0x20, 0x00, 0x00, 0x00, 0x2a,
+ },
+ wantLayers: []Layer{
+ &IPv6{
+ SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))),
+ DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))),
+ },
+ &IPv6HopByHopOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6FragmentExtHdrIdentifier),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &IPv6FragmentExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier),
+ FragmentOffset: Uint16(100),
+ MoreFragments: Bool(false),
+ Identification: Uint32(42),
+ },
+ &Payload{
+ Bytes: nil,
+ },
+ },
+ },
+ {
+ description: "IPv6/DestOpt/Fragment/Payload",
+ wantBytes: []byte{
+ // IPv6 Header
+ 0x60, 0x00, 0x00, 0x00, 0x00, 0x1b, 0x3c, 0x40, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
+ // Destination Options
+ 0x2c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ // Fragment ExtHdr
+ 0x3b, 0x00, 0x03, 0x21, 0x00, 0x00, 0x00, 0x2a,
+ // Sample Data
+ 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61,
+ },
+ wantLayers: []Layer{
+ &IPv6{
+ SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))),
+ DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))),
+ },
+ &IPv6DestinationOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6FragmentExtHdrIdentifier),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &IPv6FragmentExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier),
+ FragmentOffset: Uint16(100),
+ MoreFragments: Bool(true),
+ Identification: Uint32(42),
+ },
+ &Payload{
+ Bytes: []byte("Sample Data"),
+ },
+ },
+ },
+ {
+ description: "IPv6/Fragment/Payload",
+ wantBytes: []byte{
+ // IPv6 Header
+ 0x60, 0x00, 0x00, 0x00, 0x00, 0x13, 0x2c, 0x40, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
+ // Fragment ExtHdr
+ 0x3b, 0x00, 0x03, 0x21, 0x00, 0x00, 0x00, 0x2a,
+ // Sample Data
+ 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61,
+ },
+ wantLayers: []Layer{
+ &IPv6{
+ SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))),
+ DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))),
+ },
+ &IPv6FragmentExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier),
+ FragmentOffset: Uint16(100),
+ MoreFragments: Bool(true),
+ Identification: Uint32(42),
+ },
+ &Payload{
+ Bytes: []byte("Sample Data"),
+ },
+ },
+ },
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ layers := parse(parseIPv6, tt.wantBytes)
+ if !layers.match(tt.wantLayers) {
+ t.Fatalf("match failed with diff: %s", layers.diff(tt.wantLayers))
+ }
+ // Make sure we can generate correct next header values and checksums
+ for _, layer := range layers {
+ switch layer := layer.(type) {
+ case *IPv6HopByHopOptionsExtHdr:
+ layer.NextHeader = nil
+ case *IPv6DestinationOptionsExtHdr:
+ layer.NextHeader = nil
+ case *IPv6FragmentExtHdr:
+ layer.NextHeader = nil
+ case *ICMPv6:
+ layer.Checksum = nil
+ }
+ }
+ gotBytes, err := layers.ToBytes()
+ if err != nil {
+ t.Fatalf("ToBytes() failed on %s: %s", &layers, err)
+ }
+ if !bytes.Equal(tt.wantBytes, gotBytes) {
+ t.Fatalf("mismatching bytes, gotBytes: %x, wantBytes: %x", gotBytes, tt.wantBytes)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go
index 4665f60b2..57e822725 100644
--- a/test/packetimpact/testbench/rawsockets.go
+++ b/test/packetimpact/testbench/rawsockets.go
@@ -16,7 +16,6 @@ package testbench
import (
"encoding/binary"
- "flag"
"fmt"
"math"
"net"
@@ -29,7 +28,6 @@ import (
// Sniffer can sniff raw packets on the wire.
type Sniffer struct {
- t *testing.T
fd int
}
@@ -41,7 +39,8 @@ func htons(x uint16) uint16 {
// NewSniffer creates a Sniffer connected to *device.
func NewSniffer(t *testing.T) (Sniffer, error) {
- flag.Parse()
+ t.Helper()
+
snifferFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL)))
if err != nil {
return Sniffer{}, err
@@ -53,7 +52,6 @@ func NewSniffer(t *testing.T) (Sniffer, error) {
t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err)
}
return Sniffer{
- t: t,
fd: snifferFd,
}, nil
}
@@ -63,7 +61,9 @@ func NewSniffer(t *testing.T) (Sniffer, error) {
const maxReadSize int = 65536
// Recv tries to read one frame until the timeout is up.
-func (s *Sniffer) Recv(timeout time.Duration) []byte {
+func (s *Sniffer) Recv(t *testing.T, timeout time.Duration) []byte {
+ t.Helper()
+
deadline := time.Now().Add(timeout)
for {
timeout = deadline.Sub(time.Now())
@@ -77,7 +77,7 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte {
}
if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
- s.t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err)
+ t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err)
}
buf := make([]byte, maxReadSize)
@@ -87,10 +87,10 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte {
continue
}
if err != nil {
- s.t.Fatalf("can't read: %s", err)
+ t.Fatalf("can't read: %s", err)
}
if nread > maxReadSize {
- s.t.Fatalf("received a truncated frame of %d bytes", nread)
+ t.Fatalf("received a truncated frame of %d bytes, want at most %d bytes", nread, maxReadSize)
}
return buf[:nread]
}
@@ -98,14 +98,16 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte {
// Drain drains the Sniffer's socket receive buffer by receiving until there's
// nothing else to receive.
-func (s *Sniffer) Drain() {
- s.t.Helper()
+func (s *Sniffer) Drain(t *testing.T) {
+ t.Helper()
+
flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0)
if err != nil {
- s.t.Fatalf("failed to get sniffer socket fd flags: %s", err)
+ t.Fatalf("failed to get sniffer socket fd flags: %s", err)
}
- if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags|unix.O_NONBLOCK); err != nil {
- s.t.Fatalf("failed to make sniffer socket non-blocking: %s", err)
+ nonBlockingFlags := flags | unix.O_NONBLOCK
+ if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, nonBlockingFlags); err != nil {
+ t.Fatalf("failed to make sniffer socket non-blocking with flags %b: %s", nonBlockingFlags, err)
}
for {
buf := make([]byte, maxReadSize)
@@ -115,7 +117,7 @@ func (s *Sniffer) Drain() {
}
}
if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil {
- s.t.Fatalf("failed to restore sniffer socket fd flags: %s", err)
+ t.Fatalf("failed to restore sniffer socket fd flags to %b: %s", flags, err)
}
}
@@ -130,13 +132,13 @@ func (s *Sniffer) close() error {
// Injector can inject raw frames.
type Injector struct {
- t *testing.T
fd int
}
// NewInjector creates a new injector on *device.
func NewInjector(t *testing.T) (Injector, error) {
- flag.Parse()
+ t.Helper()
+
ifInfo, err := net.InterfaceByName(Device)
if err != nil {
return Injector{}, err
@@ -159,15 +161,20 @@ func NewInjector(t *testing.T) (Injector, error) {
return Injector{}, err
}
return Injector{
- t: t,
fd: injectFd,
}, nil
}
// Send a raw frame.
-func (i *Injector) Send(b []byte) {
- if _, err := unix.Write(i.fd, b); err != nil {
- i.t.Fatalf("can't write: %s of len %d", err, len(b))
+func (i *Injector) Send(t *testing.T, b []byte) {
+ t.Helper()
+
+ n, err := unix.Write(i.fd, b)
+ if err != nil {
+ t.Fatalf("can't write bytes of len %d: %s", len(b), err)
+ }
+ if n != len(b) {
+ t.Fatalf("got %d bytes written, want %d", n, len(b))
}
}
diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go
index a1242b189..e3629e1f3 100644
--- a/test/packetimpact/testbench/testbench.go
+++ b/test/packetimpact/testbench/testbench.go
@@ -16,29 +16,52 @@ package testbench
import (
"flag"
+ "fmt"
+ "math/rand"
+ "net"
+ "os/exec"
+ "testing"
"time"
+
+ "gvisor.dev/gvisor/test/packetimpact/netdevs"
)
var (
+ // Native indicates that the test is being run natively.
+ Native = false
// Device is the local device on the test network.
Device = ""
+
// LocalIPv4 is the local IPv4 address on the test network.
LocalIPv4 = ""
+ // RemoteIPv4 is the DUT's IPv4 address on the test network.
+ RemoteIPv4 = ""
+ // IPv4PrefixLength is the network prefix length of the IPv4 test network.
+ IPv4PrefixLength = 0
+
// LocalIPv6 is the local IPv6 address on the test network.
LocalIPv6 = ""
+ // RemoteIPv6 is the DUT's IPv6 address on the test network.
+ RemoteIPv6 = ""
+
+ // LocalInterfaceID is the ID of the local interface on the test network.
+ LocalInterfaceID uint32
+ // RemoteInterfaceID is the ID of the remote interface on the test network.
+ //
+ // Not using uint32 because package flag does not support uint32.
+ RemoteInterfaceID uint64
+
// LocalMAC is the local MAC address on the test network.
LocalMAC = ""
+ // RemoteMAC is the DUT's MAC address on the test network.
+ RemoteMAC = ""
+
// POSIXServerIP is the POSIX server's IP address on the control network.
POSIXServerIP = ""
// POSIXServerPort is the UDP port the POSIX server is bound to on the
// control network.
POSIXServerPort = 40000
- // RemoteIPv4 is the DUT's IPv4 address on the test network.
- RemoteIPv4 = ""
- // RemoteIPv6 is the DUT's IPv6 address on the test network.
- RemoteIPv6 = ""
- // RemoteMAC is the DUT's MAC address on the test network.
- RemoteMAC = ""
+
// RPCKeepalive is the gRPC keepalive.
RPCKeepalive = 10 * time.Second
// RPCTimeout is the gRPC timeout.
@@ -55,9 +78,51 @@ func RegisterFlags(fs *flag.FlagSet) {
fs.DurationVar(&RPCKeepalive, "rpc_keepalive", RPCKeepalive, "gRPC keepalive")
fs.StringVar(&LocalIPv4, "local_ipv4", LocalIPv4, "local IPv4 address for test packets")
fs.StringVar(&RemoteIPv4, "remote_ipv4", RemoteIPv4, "remote IPv4 address for test packets")
- fs.StringVar(&LocalIPv6, "local_ipv6", LocalIPv6, "local IPv6 address for test packets")
fs.StringVar(&RemoteIPv6, "remote_ipv6", RemoteIPv6, "remote IPv6 address for test packets")
- fs.StringVar(&LocalMAC, "local_mac", LocalMAC, "local mac address for test packets")
fs.StringVar(&RemoteMAC, "remote_mac", RemoteMAC, "remote mac address for test packets")
fs.StringVar(&Device, "device", Device, "local device for test packets")
+ fs.BoolVar(&Native, "native", Native, "whether the test is running natively")
+ fs.Uint64Var(&RemoteInterfaceID, "remote_interface_id", RemoteInterfaceID, "remote interface ID for test packets")
+}
+
+// genPseudoFlags populates flag-like global config based on real flags.
+//
+// genPseudoFlags must only be called after flag.Parse.
+func genPseudoFlags() error {
+ out, err := exec.Command("ip", "addr", "show").CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("listing devices: %q: %w", string(out), err)
+ }
+ devs, err := netdevs.ParseDevices(string(out))
+ if err != nil {
+ return fmt.Errorf("parsing devices: %w", err)
+ }
+
+ _, deviceInfo, err := netdevs.FindDeviceByIP(net.ParseIP(LocalIPv4), devs)
+ if err != nil {
+ return fmt.Errorf("can't find deviceInfo: %w", err)
+ }
+
+ LocalMAC = deviceInfo.MAC.String()
+ LocalIPv6 = deviceInfo.IPv6Addr.String()
+ LocalInterfaceID = deviceInfo.ID
+
+ if deviceInfo.IPv4Net != nil {
+ IPv4PrefixLength, _ = deviceInfo.IPv4Net.Mask.Size()
+ } else {
+ IPv4PrefixLength, _ = net.ParseIP(LocalIPv4).DefaultMask().Size()
+ }
+
+ return nil
+}
+
+// GenerateRandomPayload generates a random byte slice of the specified length,
+// causing a fatal test failure if it is unable to do so.
+func GenerateRandomPayload(t *testing.T, n int) []byte {
+ t.Helper()
+ buf := make([]byte, n)
+ if _, err := rand.Read(buf); err != nil {
+ t.Fatalf("rand.Read(buf) failed: %s", err)
+ }
+ return buf
}