summaryrefslogtreecommitdiffhomepage
path: root/test/packetimpact/testbench
diff options
context:
space:
mode:
Diffstat (limited to 'test/packetimpact/testbench')
-rw-r--r--test/packetimpact/testbench/BUILD12
-rw-r--r--test/packetimpact/testbench/connections.go663
-rw-r--r--test/packetimpact/testbench/dut.go362
-rw-r--r--test/packetimpact/testbench/layers.go339
-rw-r--r--test/packetimpact/testbench/layers_test.go156
-rw-r--r--test/packetimpact/testbench/rawsockets.go44
6 files changed, 1256 insertions, 320 deletions
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
index a34c81fcc..b6a254882 100644
--- a/test/packetimpact/testbench/BUILD
+++ b/test/packetimpact/testbench/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(
default_visibility = ["//test/packetimpact:__subpackages__"],
@@ -16,6 +16,7 @@ go_library(
],
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
"//pkg/usermem",
@@ -27,5 +28,14 @@ go_library(
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//keepalive:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
+ "@org_uber_go_multierr//:go_default_library",
],
)
+
+go_test(
+ name = "testbench_test",
+ size = "small",
+ srcs = ["layers_test.go"],
+ library = ":testbench",
+ deps = ["//pkg/tcpip"],
+)
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index b7aa63934..f84fd8ba7 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -21,10 +21,12 @@ import (
"fmt"
"math/rand"
"net"
+ "strings"
"testing"
"time"
"github.com/mohae/deepcopy"
+ "go.uber.org/multierr"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -36,19 +38,6 @@ var remoteIPv4 = flag.String("remote_ipv4", "", "remote IPv4 address for test pa
var localMAC = flag.String("local_mac", "", "local mac address for test packets")
var remoteMAC = flag.String("remote_mac", "", "remote mac address for test packets")
-// TCPIPv4 maintains state about a TCP/IPv4 connection.
-type TCPIPv4 struct {
- outgoing Layers
- incoming Layers
- LocalSeqNum seqnum.Value
- RemoteSeqNum seqnum.Value
- SynAck *TCP
- sniffer Sniffer
- injector Injector
- portPickerFD int
- t *testing.T
-}
-
// pickPort makes a new socket and returns the socket FD and port. The caller
// must close the FD when done with the port if there is no error.
func pickPort() (int, uint16, error) {
@@ -75,171 +64,607 @@ func pickPort() (int, uint16, error) {
return fd, uint16(newSockAddrInet4.Port), nil
}
-// tcpLayerIndex is the position of the TCP layer in the TCPIPv4 connection. It
-// is the third, after Ethernet and IPv4.
-const tcpLayerIndex int = 2
+// layerState stores the state of a layer of a connection.
+type layerState interface {
+ // outgoing returns an outgoing layer to be sent in a frame.
+ outgoing() Layer
-// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
-func NewTCPIPv4(t *testing.T, dut DUT, outgoingTCP, incomingTCP TCP) TCPIPv4 {
+ // incoming creates an expected Layer for comparing against a received Layer.
+ // Because the expectation can depend on values in the received Layer, it is
+ // an input to incoming. For example, the ACK number needs to be checked in a
+ // TCP packet but only if the ACK flag is set in the received packet.
+ incoming(received Layer) Layer
+
+ // sent updates the layerState based on the Layer that was sent. The input is
+ // a Layer with all prev and next pointers populated so that the entire frame
+ // as it was sent is available.
+ sent(sent Layer) error
+
+ // received updates the layerState based on a Layer that is receieved. The
+ // input is a Layer with all prev and next pointers populated so that the
+ // entire frame as it was receieved is available.
+ received(received Layer) error
+
+ // close frees associated resources held by the LayerState.
+ close() error
+}
+
+// etherState maintains state about an Ethernet connection.
+type etherState struct {
+ out, in Ether
+}
+
+var _ layerState = (*etherState)(nil)
+
+// newEtherState creates a new etherState.
+func newEtherState(out, in Ether) (*etherState, error) {
lMAC, err := tcpip.ParseMACAddress(*localMAC)
if err != nil {
- t.Fatalf("can't parse localMAC %q: %s", *localMAC, err)
+ return nil, err
}
rMAC, err := tcpip.ParseMACAddress(*remoteMAC)
if err != nil {
- t.Fatalf("can't parse remoteMAC %q: %s", *remoteMAC, err)
+ return nil, err
}
-
- portPickerFD, localPort, err := pickPort()
- if err != nil {
- t.Fatalf("can't pick a port: %s", err)
+ s := etherState{
+ out: Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
+ in: Ether{SrcAddr: &rMAC, DstAddr: &lMAC},
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
}
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+func (s *etherState) outgoing() Layer {
+ return &s.out
+}
+
+func (s *etherState) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*etherState) sent(Layer) error {
+ return nil
+}
+
+func (*etherState) received(Layer) error {
+ return nil
+}
+
+func (*etherState) close() error {
+ return nil
+}
+
+// ipv4State maintains state about an IPv4 connection.
+type ipv4State struct {
+ out, in IPv4
+}
+
+var _ layerState = (*ipv4State)(nil)
+
+// newIPv4State creates a new ipv4State.
+func newIPv4State(out, in IPv4) (*ipv4State, error) {
lIP := tcpip.Address(net.ParseIP(*localIPv4).To4())
rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4())
+ s := ipv4State{
+ out: IPv4{SrcAddr: &lIP, DstAddr: &rIP},
+ in: IPv4{SrcAddr: &rIP, DstAddr: &lIP},
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
- sniffer, err := NewSniffer(t)
+func (s *ipv4State) outgoing() Layer {
+ return &s.out
+}
+
+func (s *ipv4State) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*ipv4State) sent(Layer) error {
+ return nil
+}
+
+func (*ipv4State) received(Layer) error {
+ return nil
+}
+
+func (*ipv4State) close() error {
+ return nil
+}
+
+// tcpState maintains state about a TCP connection.
+type tcpState struct {
+ out, in TCP
+ localSeqNum, remoteSeqNum *seqnum.Value
+ synAck *TCP
+ portPickerFD int
+ finSent bool
+}
+
+var _ layerState = (*tcpState)(nil)
+
+// SeqNumValue is a helper routine that allocates a new seqnum.Value value to
+// store v and returns a pointer to it.
+func SeqNumValue(v seqnum.Value) *seqnum.Value {
+ return &v
+}
+
+// newTCPState creates a new TCPState.
+func newTCPState(out, in TCP) (*tcpState, error) {
+ portPickerFD, localPort, err := pickPort()
if err != nil {
- t.Fatalf("can't make new sniffer: %s", err)
+ return nil, err
+ }
+ s := tcpState{
+ out: TCP{SrcPort: &localPort},
+ in: TCP{DstPort: &localPort},
+ localSeqNum: SeqNumValue(seqnum.Value(rand.Uint32())),
+ portPickerFD: portPickerFD,
+ finSent: false,
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
}
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
- injector, err := NewInjector(t)
- if err != nil {
- t.Fatalf("can't make new injector: %s", err)
+func (s *tcpState) outgoing() Layer {
+ newOutgoing := deepcopy.Copy(s.out).(TCP)
+ if s.localSeqNum != nil {
+ newOutgoing.SeqNum = Uint32(uint32(*s.localSeqNum))
+ }
+ if s.remoteSeqNum != nil {
+ newOutgoing.AckNum = Uint32(uint32(*s.remoteSeqNum))
+ }
+ return &newOutgoing
+}
+
+func (s *tcpState) incoming(received Layer) Layer {
+ tcpReceived, ok := received.(*TCP)
+ if !ok {
+ return nil
+ }
+ newIn := deepcopy.Copy(s.in).(TCP)
+ if s.remoteSeqNum != nil {
+ newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum))
+ }
+ if s.localSeqNum != nil && (*tcpReceived.Flags&header.TCPFlagAck) != 0 {
+ // The caller didn't specify an AckNum so we'll expect the calculated one,
+ // but only if the ACK flag is set because the AckNum is not valid in a
+ // header if ACK is not set.
+ newIn.AckNum = Uint32(uint32(*s.localSeqNum))
}
+ return &newIn
+}
- newOutgoingTCP := &TCP{
- DataOffset: Uint8(header.TCPMinimumSize),
- WindowSize: Uint16(32768),
- SrcPort: &localPort,
+func (s *tcpState) sent(sent Layer) error {
+ tcp, ok := sent.(*TCP)
+ if !ok {
+ return fmt.Errorf("can't update tcpState with %T Layer", sent)
+ }
+ if !s.finSent {
+ // update localSeqNum by the payload only when FIN is not yet sent by us
+ for current := tcp.next(); current != nil; current = current.next() {
+ s.localSeqNum.UpdateForward(seqnum.Size(current.length()))
+ }
}
- if err := newOutgoingTCP.merge(outgoingTCP); err != nil {
- t.Fatalf("can't merge %v into %v: %s", outgoingTCP, newOutgoingTCP, err)
+ if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
+ s.localSeqNum.UpdateForward(1)
}
- newIncomingTCP := &TCP{
- DstPort: &localPort,
+ if *tcp.Flags&(header.TCPFlagFin) != 0 {
+ s.finSent = true
}
- if err := newIncomingTCP.merge(incomingTCP); err != nil {
- t.Fatalf("can't merge %v into %v: %s", incomingTCP, newIncomingTCP, err)
+ return nil
+}
+
+func (s *tcpState) received(l Layer) error {
+ tcp, ok := l.(*TCP)
+ if !ok {
+ return fmt.Errorf("can't update tcpState with %T Layer", l)
}
- return TCPIPv4{
- outgoing: Layers{
- &Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
- &IPv4{SrcAddr: &lIP, DstAddr: &rIP},
- newOutgoingTCP},
- incoming: Layers{
- &Ether{SrcAddr: &rMAC, DstAddr: &lMAC},
- &IPv4{SrcAddr: &rIP, DstAddr: &lIP},
- newIncomingTCP},
- sniffer: sniffer,
- injector: injector,
+ s.remoteSeqNum = SeqNumValue(seqnum.Value(*tcp.SeqNum))
+ if *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
+ s.remoteSeqNum.UpdateForward(1)
+ }
+ for current := tcp.next(); current != nil; current = current.next() {
+ s.remoteSeqNum.UpdateForward(seqnum.Size(current.length()))
+ }
+ return nil
+}
+
+// close frees the port associated with this connection.
+func (s *tcpState) close() error {
+ if err := unix.Close(s.portPickerFD); err != nil {
+ return err
+ }
+ s.portPickerFD = -1
+ return nil
+}
+
+// udpState maintains state about a UDP connection.
+type udpState struct {
+ out, in UDP
+ portPickerFD int
+}
+
+var _ layerState = (*udpState)(nil)
+
+// newUDPState creates a new udpState.
+func newUDPState(out, in UDP) (*udpState, error) {
+ portPickerFD, localPort, err := pickPort()
+ if err != nil {
+ return nil, err
+ }
+ s := udpState{
+ out: UDP{SrcPort: &localPort},
+ in: UDP{DstPort: &localPort},
portPickerFD: portPickerFD,
- t: t,
- LocalSeqNum: seqnum.Value(rand.Uint32()),
}
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
}
-// Close the injector and sniffer associated with this connection.
-func (conn *TCPIPv4) Close() {
- conn.sniffer.Close()
- conn.injector.Close()
- if err := unix.Close(conn.portPickerFD); err != nil {
- conn.t.Fatalf("can't close portPickerFD: %s", err)
+func (s *udpState) outgoing() Layer {
+ return &s.out
+}
+
+func (s *udpState) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*udpState) sent(l Layer) error {
+ return nil
+}
+
+func (*udpState) received(l Layer) error {
+ return nil
+}
+
+// close frees the port associated with this connection.
+func (s *udpState) close() error {
+ if err := unix.Close(s.portPickerFD); err != nil {
+ return err
}
- conn.portPickerFD = -1
+ s.portPickerFD = -1
+ return nil
}
-// Send a packet with reasonable defaults and override some fields by tcp.
-func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
- if tcp.SeqNum == nil {
- tcp.SeqNum = Uint32(uint32(conn.LocalSeqNum))
+// Connection holds a collection of layer states for maintaining a connection
+// along with sockets for sniffer and injecting packets.
+type Connection struct {
+ layerStates []layerState
+ injector Injector
+ sniffer Sniffer
+ t *testing.T
+}
+
+// match tries to match each Layer in received against the incoming filter. If
+// received is longer than layerStates then that may still count as a match. The
+// reverse is never a match. override overrides the default matchers for each
+// Layer.
+func (conn *Connection) match(override, received Layers) bool {
+ if len(received) < len(conn.layerStates) {
+ return false
+ }
+ for i, s := range conn.layerStates {
+ toMatch := s.incoming(received[i])
+ if toMatch == nil {
+ return false
+ }
+ if i < len(override) {
+ toMatch.merge(override[i])
+ }
+ if !toMatch.match(received[i]) {
+ return false
+ }
+ }
+ return true
+}
+
+// Close frees associated resources held by the Connection.
+func (conn *Connection) Close() {
+ errs := multierr.Combine(conn.sniffer.close(), conn.injector.close())
+ for _, s := range conn.layerStates {
+ if err := s.close(); err != nil {
+ errs = multierr.Append(errs, fmt.Errorf("unable to close %+v: %s", s, err))
+ }
}
- if tcp.AckNum == nil {
- tcp.AckNum = Uint32(uint32(conn.RemoteSeqNum))
+ if errs != nil {
+ conn.t.Fatalf("unable to close %+v: %s", conn, errs)
}
- layersToSend := deepcopy.Copy(conn.outgoing).(Layers)
- if err := layersToSend[tcpLayerIndex].(*TCP).merge(tcp); err != nil {
- conn.t.Fatalf("can't merge %v into %v: %s", tcp, layersToSend[tcpLayerIndex], err)
+}
+
+// CreateFrame builds a frame for the connection with layer overriding defaults
+// of the innermost layer and additionalLayers added after it.
+func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
+ var layersToSend Layers
+ for _, s := range conn.layerStates {
+ layersToSend = append(layersToSend, s.outgoing())
+ }
+ if err := layersToSend[len(layersToSend)-1].merge(layer); err != nil {
+ conn.t.Fatalf("can't merge %+v into %+v: %s", layer, layersToSend[len(layersToSend)-1], err)
}
layersToSend = append(layersToSend, additionalLayers...)
- outBytes, err := layersToSend.toBytes()
+ return layersToSend
+}
+
+// SendFrame sends a frame on the wire and updates the state of all layers.
+func (conn *Connection) SendFrame(frame Layers) {
+ outBytes, err := frame.toBytes()
if err != nil {
conn.t.Fatalf("can't build outgoing TCP packet: %s", err)
}
conn.injector.Send(outBytes)
- // Compute the next TCP sequence number.
- for i := tcpLayerIndex + 1; i < len(layersToSend); i++ {
- conn.LocalSeqNum.UpdateForward(seqnum.Size(layersToSend[i].length()))
+ // frame might have nil values where the caller wanted to use default values.
+ // sentFrame will have no nil values in it because it comes from parsing the
+ // bytes that were actually sent.
+ sentFrame := parse(parseEther, outBytes)
+ // Update the state of each layer based on what was sent.
+ for i, s := range conn.layerStates {
+ if err := s.sent(sentFrame[i]); err != nil {
+ conn.t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err)
+ }
}
- if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
- conn.LocalSeqNum.UpdateForward(1)
+}
+
+// Send a packet with reasonable defaults. Potentially override the final layer
+// in the connection with the provided layer and add additionLayers.
+func (conn *Connection) Send(layer Layer, additionalLayers ...Layer) {
+ conn.SendFrame(conn.CreateFrame(layer, additionalLayers...))
+}
+
+// recvFrame gets the next successfully parsed frame (of type Layers) within the
+// timeout provided. If no parsable frame arrives before the timeout, it returns
+// nil.
+func (conn *Connection) recvFrame(timeout time.Duration) Layers {
+ if timeout <= 0 {
+ return nil
}
+ b := conn.sniffer.Recv(timeout)
+ if b == nil {
+ return nil
+ }
+ return parse(parseEther, b)
}
-// Recv gets a packet from the sniffer within the timeout provided. If no packet
-// arrives before the timeout, it returns nil.
-func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP {
- deadline := time.Now().Add(timeout)
- for {
- timeout = deadline.Sub(time.Now())
- if timeout <= 0 {
- break
- }
- b := conn.sniffer.Recv(timeout)
- if b == nil {
- break
- }
- layers, err := ParseEther(b)
- if err != nil {
- continue // Ignore packets that can't be parsed.
- }
- if !conn.incoming.match(layers) {
- continue // Ignore packets that don't match the expected incoming.
- }
- tcpHeader := (layers[tcpLayerIndex]).(*TCP)
- conn.RemoteSeqNum = seqnum.Value(*tcpHeader.SeqNum)
- if *tcpHeader.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
- conn.RemoteSeqNum.UpdateForward(1)
- }
- for i := tcpLayerIndex + 1; i < len(layers); i++ {
- conn.RemoteSeqNum.UpdateForward(seqnum.Size(layers[i].length()))
- }
- return tcpHeader
+// Expect a frame with the final layerStates layer matching the provided Layer
+// within the timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) {
+ // Make a frame that will ignore all but the final layer.
+ layers := make([]Layer, len(conn.layerStates))
+ layers[len(layers)-1] = layer
+
+ gotFrame, err := conn.ExpectFrame(layers, timeout)
+ if err != nil {
+ return nil, err
}
- return nil
+ if len(conn.layerStates)-1 < len(gotFrame) {
+ return gotFrame[len(conn.layerStates)-1], nil
+ }
+ conn.t.Fatal("the received frame should be at least as long as the expected layers")
+ return nil, fmt.Errorf("the received frame should be at least as long as the expected layers")
}
-// Expect a packet that matches the provided tcp within the timeout specified.
-// If it doesn't arrive in time, the test fails.
-func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) *TCP {
+// ExpectFrame expects a frame that matches the provided Layers within the
+// timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) {
deadline := time.Now().Add(timeout)
+ var allLayers []string
for {
- timeout = deadline.Sub(time.Now())
- if timeout <= 0 {
- return nil
+ var gotLayers Layers
+ if timeout = time.Until(deadline); timeout > 0 {
+ gotLayers = conn.recvFrame(timeout)
}
- gotTCP := conn.Recv(timeout)
- if gotTCP == nil {
- return nil
+ if gotLayers == nil {
+ return nil, fmt.Errorf("got %d packets:\n%s", len(allLayers), strings.Join(allLayers, "\n"))
}
- if tcp.match(gotTCP) {
- return gotTCP
+ if conn.match(layers, gotLayers) {
+ for i, s := range conn.layerStates {
+ if err := s.received(gotLayers[i]); err != nil {
+ conn.t.Fatal(err)
+ }
+ }
+ return gotLayers, nil
}
+ allLayers = append(allLayers, fmt.Sprintf("%s", gotLayers))
+ }
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *Connection) Drain() {
+ conn.sniffer.Drain()
+}
+
+// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection.
+type TCPIPv4 Connection
+
+// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
+func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make etherState: %s", err)
+ }
+ ipv4State, err := newIPv4State(IPv4{}, IPv4{})
+ if err != nil {
+ t.Fatalf("can't make ipv4State: %s", err)
+ }
+ tcpState, err := newTCPState(outgoingTCP, incomingTCP)
+ if err != nil {
+ t.Fatalf("can't make tcpState: %s", err)
+ }
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return TCPIPv4{
+ layerStates: []layerState{etherState, ipv4State, tcpState},
+ injector: injector,
+ sniffer: sniffer,
+ t: t,
}
}
-// Handshake performs a TCP 3-way handshake.
+// Handshake performs a TCP 3-way handshake. The input Connection should have a
+// final TCP Layer.
func (conn *TCPIPv4) Handshake() {
// Send the SYN.
conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)})
// Wait for the SYN-ACK.
- conn.SynAck = conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
- if conn.SynAck == nil {
- conn.t.Fatalf("didn't get synack during handshake")
+ synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if synAck == nil {
+ conn.t.Fatalf("didn't get synack during handshake: %s", err)
}
+ conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
// Send an ACK.
conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)})
}
+
+// ExpectData is a convenient method that expects a Layer and the Layer after
+// it. If it doens't arrive in time, it returns nil.
+func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = tcp
+ if payload != nil {
+ expected = append(expected, payload)
+ }
+ return (*Connection)(conn).ExpectFrame(expected, timeout)
+}
+
+// Send a packet with reasonable defaults. Potentially override the TCP layer in
+// the connection with the provided layer and add additionLayers.
+func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
+ (*Connection)(conn).Send(&tcp, additionalLayers...)
+}
+
+// Close frees associated resources held by the TCPIPv4 connection.
+func (conn *TCPIPv4) Close() {
+ (*Connection)(conn).Close()
+}
+
+// Expect a frame with the TCP layer matching the provided TCP within the
+// timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) {
+ layer, err := (*Connection)(conn).Expect(&tcp, timeout)
+ if layer == nil {
+ return nil, err
+ }
+ gotTCP, ok := layer.(*TCP)
+ if !ok {
+ conn.t.Fatalf("expected %s to be TCP", layer)
+ }
+ return gotTCP, err
+}
+
+func (conn *TCPIPv4) state() *tcpState {
+ state, ok := conn.layerStates[len(conn.layerStates)-1].(*tcpState)
+ if !ok {
+ conn.t.Fatalf("expected final state of %v to be tcpState", conn.layerStates)
+ }
+ return state
+}
+
+// RemoteSeqNum returns the next expected sequence number from the DUT.
+func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value {
+ return conn.state().remoteSeqNum
+}
+
+// LocalSeqNum returns the next sequence number to send from the testbench.
+func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value {
+ return conn.state().localSeqNum
+}
+
+// SynAck returns the SynAck that was part of the handshake.
+func (conn *TCPIPv4) SynAck() *TCP {
+ return conn.state().synAck
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *TCPIPv4) Drain() {
+ conn.sniffer.Drain()
+}
+
+// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection.
+type UDPIPv4 Connection
+
+// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults.
+func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make etherState: %s", err)
+ }
+ ipv4State, err := newIPv4State(IPv4{}, IPv4{})
+ if err != nil {
+ t.Fatalf("can't make ipv4State: %s", err)
+ }
+ tcpState, err := newUDPState(outgoingUDP, incomingUDP)
+ if err != nil {
+ t.Fatalf("can't make udpState: %s", err)
+ }
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return UDPIPv4{
+ layerStates: []layerState{etherState, ipv4State, tcpState},
+ injector: injector,
+ sniffer: sniffer,
+ t: t,
+ }
+}
+
+// CreateFrame builds a frame for the connection with layer overriding defaults
+// of the innermost layer and additionalLayers added after it.
+func (conn *UDPIPv4) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
+ return (*Connection)(conn).CreateFrame(layer, additionalLayers...)
+}
+
+// SendFrame sends a frame on the wire and updates the state of all layers.
+func (conn *UDPIPv4) SendFrame(frame Layers) {
+ (*Connection)(conn).SendFrame(frame)
+}
+
+// Close frees associated resources held by the UDPIPv4 connection.
+func (conn *UDPIPv4) Close() {
+ (*Connection)(conn).Close()
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *UDPIPv4) Drain() {
+ conn.sniffer.Drain()
+}
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
index 8ea1706d3..9335909c0 100644
--- a/test/packetimpact/testbench/dut.go
+++ b/test/packetimpact/testbench/dut.go
@@ -65,33 +65,6 @@ func (dut *DUT) TearDown() {
dut.conn.Close()
}
-// SocketWithErrno calls socket on the DUT and returns the fd and errno.
-func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
- dut.t.Helper()
- req := pb.SocketRequest{
- Domain: domain,
- Type: typ,
- Protocol: proto,
- }
- ctx := context.Background()
- resp, err := dut.posixServer.Socket(ctx, &req)
- if err != nil {
- dut.t.Fatalf("failed to call Socket: %s", err)
- }
- return resp.GetFd(), syscall.Errno(resp.GetErrno_())
-}
-
-// Socket calls socket on the DUT and returns the file descriptor. If socket
-// fails on the DUT, the test ends.
-func (dut *DUT) Socket(domain, typ, proto int32) int32 {
- dut.t.Helper()
- fd, err := dut.SocketWithErrno(domain, typ, proto)
- if fd < 0 {
- dut.t.Fatalf("failed to create socket: %s", err)
- }
- return fd
-}
-
func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr {
dut.t.Helper()
switch s := sa.(type) {
@@ -142,14 +115,95 @@ func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr {
return nil
}
+// CreateBoundSocket makes a new socket on the DUT, with type typ and protocol
+// proto, and bound to the IP address addr. Returns the new file descriptor and
+// the port that was selected on the DUT.
+func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) {
+ dut.t.Helper()
+ var fd int32
+ if addr.To4() != nil {
+ fd = dut.Socket(unix.AF_INET, typ, proto)
+ sa := unix.SockaddrInet4{}
+ copy(sa.Addr[:], addr.To4())
+ dut.Bind(fd, &sa)
+ } else if addr.To16() != nil {
+ fd = dut.Socket(unix.AF_INET6, typ, proto)
+ sa := unix.SockaddrInet6{}
+ copy(sa.Addr[:], addr.To16())
+ dut.Bind(fd, &sa)
+ } else {
+ dut.t.Fatal("unknown ip addr type for remoteIP")
+ }
+ sa := dut.GetSockName(fd)
+ var port int
+ switch s := sa.(type) {
+ case *unix.SockaddrInet4:
+ port = s.Port
+ case *unix.SockaddrInet6:
+ port = s.Port
+ default:
+ dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa)
+ }
+ return fd, uint16(port)
+}
+
+// CreateListener makes a new TCP connection. If it fails, the test ends.
+func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) {
+ fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(*remoteIPv4))
+ dut.Listen(fd, backlog)
+ return fd, remotePort
+}
+
+// All the functions that make gRPC calls to the Posix service are below, sorted
+// alphabetically.
+
+// Accept calls accept on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// AcceptWithErrno.
+func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ fd, sa, err := dut.AcceptWithErrno(ctx, sockfd)
+ if fd < 0 {
+ dut.t.Fatalf("failed to accept: %s", err)
+ }
+ return fd, sa
+}
+
+// AcceptWithErrno calls accept on the DUT.
+func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
+ dut.t.Helper()
+ req := pb.AcceptRequest{
+ Sockfd: sockfd,
+ }
+ resp, err := dut.posixServer.Accept(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Accept: %s", err)
+ }
+ return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+}
+
+// Bind calls bind on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is
+// needed, use BindWithErrno.
+func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.BindWithErrno(ctx, fd, sa)
+ if ret != 0 {
+ dut.t.Fatalf("failed to bind socket: %s", err)
+ }
+}
+
// BindWithErrno calls bind on the DUT.
-func (dut *DUT) BindWithErrno(fd int32, sa unix.Sockaddr) (int32, error) {
+func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) {
dut.t.Helper()
req := pb.BindRequest{
Sockfd: fd,
Addr: dut.sockaddrToProto(sa),
}
- ctx := context.Background()
resp, err := dut.posixServer.Bind(ctx, &req)
if err != nil {
dut.t.Fatalf("failed to call Bind: %s", err)
@@ -157,23 +211,52 @@ func (dut *DUT) BindWithErrno(fd int32, sa unix.Sockaddr) (int32, error) {
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-// Bind calls bind on the DUT and causes a fatal test failure if it doesn't
-// succeed.
-func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) {
+// Close calls close on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// CloseWithErrno.
+func (dut *DUT) Close(fd int32) {
dut.t.Helper()
- ret, err := dut.BindWithErrno(fd, sa)
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.CloseWithErrno(ctx, fd)
if ret != 0 {
- dut.t.Fatalf("failed to bind socket: %s", err)
+ dut.t.Fatalf("failed to close: %s", err)
}
}
+// CloseWithErrno calls close on the DUT.
+func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.CloseRequest{
+ Fd: fd,
+ }
+ resp, err := dut.posixServer.Close(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Close: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// GetSockName calls getsockname on the DUT and causes a fatal test failure if
+// it doesn't succeed. If more control over the timeout or error handling is
+// needed, use GetSockNameWithErrno.
+func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, sa, err := dut.GetSockNameWithErrno(ctx, sockfd)
+ if ret != 0 {
+ dut.t.Fatalf("failed to getsockname: %s", err)
+ }
+ return sa
+}
+
// GetSockNameWithErrno calls getsockname on the DUT.
-func (dut *DUT) GetSockNameWithErrno(sockfd int32) (int32, unix.Sockaddr, error) {
+func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
dut.t.Helper()
req := pb.GetSockNameRequest{
Sockfd: sockfd,
}
- ctx := context.Background()
resp, err := dut.posixServer.GetSockName(ctx, &req)
if err != nil {
dut.t.Fatalf("failed to call Bind: %s", err)
@@ -181,26 +264,26 @@ func (dut *DUT) GetSockNameWithErrno(sockfd int32) (int32, unix.Sockaddr, error)
return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
}
-// GetSockName calls getsockname on the DUT and causes a fatal test failure if
-// it doens't succeed.
-func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr {
+// Listen calls listen on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// ListenWithErrno.
+func (dut *DUT) Listen(sockfd, backlog int32) {
dut.t.Helper()
- ret, sa, err := dut.GetSockNameWithErrno(sockfd)
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.ListenWithErrno(ctx, sockfd, backlog)
if ret != 0 {
- dut.t.Fatalf("failed to getsockname: %s", err)
+ dut.t.Fatalf("failed to listen: %s", err)
}
- return sa
}
// ListenWithErrno calls listen on the DUT.
-func (dut *DUT) ListenWithErrno(sockfd, backlog int32) (int32, error) {
+func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int32, error) {
dut.t.Helper()
req := pb.ListenRequest{
Sockfd: sockfd,
Backlog: backlog,
}
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
- defer cancel()
resp, err := dut.posixServer.Listen(ctx, &req)
if err != nil {
dut.t.Fatalf("failed to call Listen: %s", err)
@@ -208,44 +291,54 @@ func (dut *DUT) ListenWithErrno(sockfd, backlog int32) (int32, error) {
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-// Listen calls listen on the DUT and causes a fatal test failure if it doesn't
-// succeed.
-func (dut *DUT) Listen(sockfd, backlog int32) {
+// Send calls send on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// SendWithErrno.
+func (dut *DUT) Send(sockfd int32, buf []byte, flags int32) int32 {
dut.t.Helper()
- ret, err := dut.ListenWithErrno(sockfd, backlog)
- if ret != 0 {
- dut.t.Fatalf("failed to listen: %s", err)
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.SendWithErrno(ctx, sockfd, buf, flags)
+ if ret == -1 {
+ dut.t.Fatalf("failed to send: %s", err)
}
+ return ret
}
-// AcceptWithErrno calls accept on the DUT.
-func (dut *DUT) AcceptWithErrno(sockfd int32) (int32, unix.Sockaddr, error) {
+// SendWithErrno calls send on the DUT.
+func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32) (int32, error) {
dut.t.Helper()
- req := pb.AcceptRequest{
+ req := pb.SendRequest{
Sockfd: sockfd,
+ Buf: buf,
+ Flags: flags,
}
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
- defer cancel()
- resp, err := dut.posixServer.Accept(ctx, &req)
+ resp, err := dut.posixServer.Send(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Accept: %s", err)
+ dut.t.Fatalf("failed to call Send: %s", err)
}
- return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-// Accept calls accept on the DUT and causes a fatal test failure if it doesn't
-// succeed.
-func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) {
+// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use SetSockOptWithErrno. Because endianess and the width of values
+// might differ between the testbench and DUT architectures, prefer to use a
+// more specific SetSockOptXxx function.
+func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) {
dut.t.Helper()
- fd, sa, err := dut.AcceptWithErrno(sockfd)
- if fd < 0 {
- dut.t.Fatalf("failed to accept: %s", err)
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval)
+ if ret != 0 {
+ dut.t.Fatalf("failed to SetSockOpt: %s", err)
}
- return fd, sa
}
-// SetSockOptWithErrno calls setsockopt on the DUT.
-func (dut *DUT) SetSockOptWithErrno(sockfd, level, optname int32, optval []byte) (int32, error) {
+// SetSockOptWithErrno calls setsockopt on the DUT. Because endianess and the
+// width of values might differ between the testbench and DUT architectures,
+// prefer to use a more specific SetSockOptXxxWithErrno function.
+func (dut *DUT) SetSockOptWithErrno(ctx context.Context, sockfd, level, optname int32, optval []byte) (int32, error) {
dut.t.Helper()
req := pb.SetSockOptRequest{
Sockfd: sockfd,
@@ -253,8 +346,6 @@ func (dut *DUT) SetSockOptWithErrno(sockfd, level, optname int32, optval []byte)
Optname: optname,
Optval: optval,
}
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
- defer cancel()
resp, err := dut.posixServer.SetSockOpt(ctx, &req)
if err != nil {
dut.t.Fatalf("failed to call SetSockOpt: %s", err)
@@ -262,19 +353,51 @@ func (dut *DUT) SetSockOptWithErrno(sockfd, level, optname int32, optval []byte)
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it
-// doesn't succeed.
-func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) {
+// SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the int optval or error handling
+// is needed, use SetSockOptIntWithErrno.
+func (dut *DUT) SetSockOptInt(sockfd, level, optname, optval int32) {
dut.t.Helper()
- ret, err := dut.SetSockOptWithErrno(sockfd, level, optname, optval)
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.SetSockOptIntWithErrno(ctx, sockfd, level, optname, optval)
if ret != 0 {
- dut.t.Fatalf("failed to SetSockOpt: %s", err)
+ dut.t.Fatalf("failed to SetSockOptInt: %s", err)
+ }
+}
+
+// SetSockOptIntWithErrno calls setsockopt with an integer optval.
+func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname, optval int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.SetSockOptIntRequest{
+ Sockfd: sockfd,
+ Level: level,
+ Optname: optname,
+ Intval: optval,
+ }
+ resp, err := dut.posixServer.SetSockOptInt(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call SetSockOptInt: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the timeout or error handling is
+// needed, use SetSockOptTimevalWithErrno.
+func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.SetSockOptTimevalWithErrno(ctx, sockfd, level, optname, tv)
+ if ret != 0 {
+ dut.t.Fatalf("failed to SetSockOptTimeval: %s", err)
}
}
// SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to
// bytes.
-func (dut *DUT) SetSockOptTimevalWithErrno(sockfd, level, optname int32, tv *unix.Timeval) (int32, error) {
+func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) {
dut.t.Helper()
timeval := pb.Timeval{
Seconds: int64(tv.Sec),
@@ -286,8 +409,6 @@ func (dut *DUT) SetSockOptTimevalWithErrno(sockfd, level, optname int32, tv *uni
Optname: optname,
Timeval: &timeval,
}
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
- defer cancel()
resp, err := dut.posixServer.SetSockOptTimeval(ctx, &req)
if err != nil {
dut.t.Fatalf("failed to call SetSockOptTimeval: %s", err)
@@ -295,69 +416,58 @@ func (dut *DUT) SetSockOptTimevalWithErrno(sockfd, level, optname int32, tv *uni
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure
-// if it doesn't succeed.
-func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) {
+// Socket calls socket on the DUT and returns the file descriptor. If socket
+// fails on the DUT, the test ends.
+func (dut *DUT) Socket(domain, typ, proto int32) int32 {
dut.t.Helper()
- ret, err := dut.SetSockOptTimevalWithErrno(sockfd, level, optname, tv)
- if ret != 0 {
- dut.t.Fatalf("failed to SetSockOptTimeval: %s", err)
+ fd, err := dut.SocketWithErrno(domain, typ, proto)
+ if fd < 0 {
+ dut.t.Fatalf("failed to create socket: %s", err)
}
+ return fd
}
-// CloseWithErrno calls close on the DUT.
-func (dut *DUT) CloseWithErrno(fd int32) (int32, error) {
+// SocketWithErrno calls socket on the DUT and returns the fd and errno.
+func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
dut.t.Helper()
- req := pb.CloseRequest{
- Fd: fd,
+ req := pb.SocketRequest{
+ Domain: domain,
+ Type: typ,
+ Protocol: proto,
}
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
- defer cancel()
- resp, err := dut.posixServer.Close(ctx, &req)
+ ctx := context.Background()
+ resp, err := dut.posixServer.Socket(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Close: %s", err)
+ dut.t.Fatalf("failed to call Socket: %s", err)
}
- return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+ return resp.GetFd(), syscall.Errno(resp.GetErrno_())
}
-// Close calls close on the DUT and causes a fatal test failure if it doesn't
-// succeed.
-func (dut *DUT) Close(fd int32) {
+// Recv calls recv on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// RecvWithErrno.
+func (dut *DUT) Recv(sockfd, len, flags int32) []byte {
dut.t.Helper()
- ret, err := dut.CloseWithErrno(fd)
- if ret != 0 {
- dut.t.Fatalf("failed to close: %s", err)
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags)
+ if ret == -1 {
+ dut.t.Fatalf("failed to recv: %s", err)
}
+ return buf
}
-// CreateListener makes a new TCP connection. If it fails, the test ends.
-func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) {
+// RecvWithErrno calls recv on the DUT.
+func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (int32, []byte, error) {
dut.t.Helper()
- addr := net.ParseIP(*remoteIPv4)
- var fd int32
- if addr.To4() != nil {
- fd = dut.Socket(unix.AF_INET, typ, proto)
- sa := unix.SockaddrInet4{}
- copy(sa.Addr[:], addr.To4())
- dut.Bind(fd, &sa)
- } else if addr.To16() != nil {
- fd = dut.Socket(unix.AF_INET6, typ, proto)
- sa := unix.SockaddrInet6{}
- copy(sa.Addr[:], addr.To16())
- dut.Bind(fd, &sa)
- } else {
- dut.t.Fatal("unknown ip addr type for remoteIP")
+ req := pb.RecvRequest{
+ Sockfd: sockfd,
+ Len: len,
+ Flags: flags,
}
- sa := dut.GetSockName(fd)
- var port int
- switch s := sa.(type) {
- case *unix.SockaddrInet4:
- port = s.Port
- case *unix.SockaddrInet6:
- port = s.Port
- default:
- dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa)
+ resp, err := dut.posixServer.Recv(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Recv: %s", err)
}
- dut.Listen(fd, backlog)
- return fd, uint16(port)
+ return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_())
}
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index 35fa4dcb6..5ce324f0d 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -15,13 +15,16 @@
package testbench
import (
+ "encoding/hex"
"fmt"
"reflect"
+ "strings"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/imdario/mergo"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -31,6 +34,8 @@ import (
// Layer contains all the fields of the encapsulation. Each field is a pointer
// and may be nil.
type Layer interface {
+ fmt.Stringer
+
// toBytes converts the Layer into bytes. In places where the Layer's field
// isn't nil, the value that is pointed to is used. When the field is nil, a
// reasonable default for the Layer is used. For example, "64" for IPv4 TTL
@@ -42,7 +47,8 @@ type Layer interface {
// match checks if the current Layer matches the provided Layer. If either
// Layer has a nil in a given field, that field is considered matching.
- // Otherwise, the values pointed to by the fields must match.
+ // Otherwise, the values pointed to by the fields must match. The LayerBase is
+ // ignored.
match(Layer) bool
// length in bytes of the current encapsulation
@@ -59,6 +65,9 @@ type Layer interface {
// setPrev sets the pointer to the Layer encapsulating this one.
setPrev(Layer)
+
+ // merge overrides the values in the interface with the provided values.
+ merge(Layer) error
}
// LayerBase is the common elements of all layers.
@@ -83,21 +92,59 @@ func (lb *LayerBase) setPrev(l Layer) {
lb.prevLayer = l
}
+// equalLayer compares that two Layer structs match while ignoring field in
+// which either input has a nil and also ignoring the LayerBase of the inputs.
func equalLayer(x, y Layer) bool {
+ if x == nil || y == nil {
+ return true
+ }
+ // opt ignores comparison pairs where either of the inputs is a nil.
opt := cmp.FilterValues(func(x, y interface{}) bool {
- if reflect.ValueOf(x).Kind() == reflect.Ptr && reflect.ValueOf(x).IsNil() {
- return true
- }
- if reflect.ValueOf(y).Kind() == reflect.Ptr && reflect.ValueOf(y).IsNil() {
- return true
+ for _, l := range []interface{}{x, y} {
+ v := reflect.ValueOf(l)
+ if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice) && v.IsNil() {
+ return true
+ }
}
return false
-
}, cmp.Ignore())
- return cmp.Equal(x, y, opt, cmpopts.IgnoreUnexported(LayerBase{}))
+ return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{}))
+}
+
+// mergeLayer merges other in layer. Any non-nil value in other overrides the
+// corresponding value in layer. If other is nil, no action is performed.
+func mergeLayer(layer, other Layer) error {
+ if other == nil {
+ return nil
+ }
+ return mergo.Merge(layer, other, mergo.WithOverride)
+}
+
+func stringLayer(l Layer) string {
+ v := reflect.ValueOf(l).Elem()
+ t := v.Type()
+ var ret []string
+ for i := 0; i < v.NumField(); i++ {
+ t := t.Field(i)
+ if t.Anonymous {
+ // Ignore the LayerBase in the Layer struct.
+ continue
+ }
+ v := v.Field(i)
+ if v.IsNil() {
+ continue
+ }
+ v = reflect.Indirect(v)
+ if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
+ ret = append(ret, fmt.Sprintf("%s:\n%v", t.Name, hex.Dump(v.Bytes())))
+ } else {
+ ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v))
+ }
+ }
+ return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " "))
}
-// Ether can construct and match the ethernet encapsulation.
+// Ether can construct and match an ethernet encapsulation.
type Ether struct {
LayerBase
SrcAddr *tcpip.LinkAddress
@@ -105,6 +152,10 @@ type Ether struct {
Type *tcpip.NetworkProtocolNumber
}
+func (l *Ether) String() string {
+ return stringLayer(l)
+}
+
func (l *Ether) toBytes() ([]byte, error) {
b := make([]byte, header.EthernetMinimumSize)
h := header.Ethernet(b)
@@ -123,7 +174,7 @@ func (l *Ether) toBytes() ([]byte, error) {
fields.Type = header.IPv4ProtocolNumber
default:
// TODO(b/150301488): Support more protocols, like IPv6.
- return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", n)
+ return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n)
}
}
h.Encode(fields)
@@ -142,27 +193,46 @@ func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocol
return &v
}
-// ParseEther parses the bytes assuming that they start with an ethernet header
+// layerParser parses the input bytes and returns a Layer along with the next
+// layerParser to run. If there is no more parsing to do, the returned
+// layerParser is nil.
+type layerParser func([]byte) (Layer, layerParser)
+
+// parse parses bytes starting with the first layerParser and using successive
+// layerParsers until all the bytes are parsed.
+func parse(parser layerParser, b []byte) Layers {
+ var layers Layers
+ for {
+ var layer Layer
+ layer, parser = parser(b)
+ layers = append(layers, layer)
+ if parser == nil {
+ break
+ }
+ b = b[layer.length():]
+ }
+ layers.linkLayers()
+ return layers
+}
+
+// parseEther parses the bytes assuming that they start with an ethernet header
// and continues parsing further encapsulations.
-func ParseEther(b []byte) (Layers, error) {
+func parseEther(b []byte) (Layer, layerParser) {
h := header.Ethernet(b)
ether := Ether{
SrcAddr: LinkAddress(h.SourceAddress()),
DstAddr: LinkAddress(h.DestinationAddress()),
Type: NetworkProtocolNumber(h.Type()),
}
- layers := Layers{&ether}
+ var nextParser layerParser
switch h.Type() {
case header.IPv4ProtocolNumber:
- moreLayers, err := ParseIPv4(b[ether.length():])
- if err != nil {
- return nil, err
- }
- return append(layers, moreLayers...), nil
+ nextParser = parseIPv4
default:
- // TODO(b/150301488): Support more protocols, like IPv6.
- return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %v", b)
+ // Assume that the rest is a payload.
+ nextParser = parsePayload
}
+ return &ether, nextParser
}
func (l *Ether) match(other Layer) bool {
@@ -173,7 +243,13 @@ func (l *Ether) length() int {
return header.EthernetMinimumSize
}
-// IPv4 can construct and match the ethernet excapulation.
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *Ether) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// IPv4 can construct and match an IPv4 encapsulation.
type IPv4 struct {
LayerBase
IHL *uint8
@@ -189,6 +265,10 @@ type IPv4 struct {
DstAddr *tcpip.Address
}
+func (l *IPv4) String() string {
+ return stringLayer(l)
+}
+
func (l *IPv4) toBytes() ([]byte, error) {
b := make([]byte, header.IPv4MinimumSize)
h := header.IPv4(b)
@@ -236,9 +316,11 @@ func (l *IPv4) toBytes() ([]byte, error) {
switch n := l.next().(type) {
case *TCP:
fields.Protocol = uint8(header.TCPProtocolNumber)
+ case *UDP:
+ fields.Protocol = uint8(header.UDPProtocolNumber)
default:
- // TODO(b/150301488): Support more protocols, like UDP.
- return nil, fmt.Errorf("can't deduce the ip header's next protocol: %+v", n)
+ // TODO(b/150301488): Support more protocols as needed.
+ return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n)
}
}
if l.SrcAddr != nil {
@@ -275,9 +357,9 @@ func Address(v tcpip.Address) *tcpip.Address {
return &v
}
-// ParseIPv4 parses the bytes assuming that they start with an ipv4 header and
+// parseIPv4 parses the bytes assuming that they start with an ipv4 header and
// continues parsing further encapsulations.
-func ParseIPv4(b []byte) (Layers, error) {
+func parseIPv4(b []byte) (Layer, layerParser) {
h := header.IPv4(b)
tos, _ := h.TOS()
ipv4 := IPv4{
@@ -293,16 +375,17 @@ func ParseIPv4(b []byte) (Layers, error) {
SrcAddr: Address(h.SourceAddress()),
DstAddr: Address(h.DestinationAddress()),
}
- layers := Layers{&ipv4}
- switch h.Protocol() {
- case uint8(header.TCPProtocolNumber):
- moreLayers, err := ParseTCP(b[ipv4.length():])
- if err != nil {
- return nil, err
- }
- return append(layers, moreLayers...), nil
+ var nextParser layerParser
+ switch h.TransportProtocol() {
+ case header.TCPProtocolNumber:
+ nextParser = parseTCP
+ case header.UDPProtocolNumber:
+ nextParser = parseUDP
+ default:
+ // Assume that the rest is a payload.
+ nextParser = parsePayload
}
- return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", h.Protocol())
+ return &ipv4, nextParser
}
func (l *IPv4) match(other Layer) bool {
@@ -316,7 +399,13 @@ func (l *IPv4) length() int {
return int(*l.IHL)
}
-// TCP can construct and match the TCP excapulation.
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *IPv4) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// TCP can construct and match a TCP encapsulation.
type TCP struct {
LayerBase
SrcPort *uint16
@@ -330,6 +419,10 @@ type TCP struct {
UrgentPointer *uint16
}
+func (l *TCP) String() string {
+ return stringLayer(l)
+}
+
func (l *TCP) toBytes() ([]byte, error) {
b := make([]byte, header.TCPMinimumSize)
h := header.TCP(b)
@@ -347,12 +440,16 @@ func (l *TCP) toBytes() ([]byte, error) {
}
if l.DataOffset != nil {
h.SetDataOffset(*l.DataOffset)
+ } else {
+ h.SetDataOffset(uint8(l.length()))
}
if l.Flags != nil {
h.SetFlags(*l.Flags)
}
if l.WindowSize != nil {
h.SetWindowSize(*l.WindowSize)
+ } else {
+ h.SetWindowSize(32768)
}
if l.UrgentPointer != nil {
h.SetUrgentPoiner(*l.UrgentPointer)
@@ -361,38 +458,52 @@ func (l *TCP) toBytes() ([]byte, error) {
h.SetChecksum(*l.Checksum)
return h, nil
}
- if err := setChecksum(&h, l); err != nil {
+ if err := setTCPChecksum(&h, l); err != nil {
return nil, err
}
return h, nil
}
-// setChecksum calculates the checksum of the TCP header and sets it in h.
-func setChecksum(h *header.TCP, tcp *TCP) error {
- h.SetChecksum(0)
- tcpLength := uint16(tcp.length())
- current := tcp.next()
- for current != nil {
- tcpLength += uint16(current.length())
- current = current.next()
+// totalLength returns the length of the provided layer and all following
+// layers.
+func totalLength(l Layer) int {
+ var totalLength int
+ for ; l != nil; l = l.next() {
+ totalLength += l.length()
}
+ return totalLength
+}
+// layerChecksum calculates the checksum of the Layer header, including the
+// peusdeochecksum of the layer before it and all the bytes after it..
+func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) {
+ totalLength := uint16(totalLength(l))
var xsum uint16
- switch s := tcp.prev().(type) {
+ switch s := l.prev().(type) {
case *IPv4:
- xsum = header.PseudoHeaderChecksum(header.TCPProtocolNumber, *s.SrcAddr, *s.DstAddr, tcpLength)
+ xsum = header.PseudoHeaderChecksum(protoNumber, *s.SrcAddr, *s.DstAddr, totalLength)
default:
// TODO(b/150301488): Support more protocols, like IPv6.
- return fmt.Errorf("can't get src and dst addr from previous layer")
+ return 0, fmt.Errorf("can't get src and dst addr from previous layer: %#v", s)
}
- current = tcp.next()
- for current != nil {
+ var payloadBytes buffer.VectorisedView
+ for current := l.next(); current != nil; current = current.next() {
payload, err := current.toBytes()
if err != nil {
- return fmt.Errorf("can't get bytes for next header: %s", payload)
+ return 0, fmt.Errorf("can't get bytes for next header: %s", payload)
}
- xsum = header.Checksum(payload, xsum)
- current = current.next()
+ payloadBytes.AppendView(payload)
+ }
+ xsum = header.ChecksumVV(payloadBytes, xsum)
+ return xsum, nil
+}
+
+// setTCPChecksum calculates the checksum of the TCP header and sets it in h.
+func setTCPChecksum(h *header.TCP, tcp *TCP) error {
+ h.SetChecksum(0)
+ xsum, err := layerChecksum(tcp, header.TCPProtocolNumber)
+ if err != nil {
+ return err
}
h.SetChecksum(^h.CalculateChecksum(xsum))
return nil
@@ -404,9 +515,9 @@ func Uint32(v uint32) *uint32 {
return &v
}
-// ParseTCP parses the bytes assuming that they start with a tcp header and
+// parseTCP parses the bytes assuming that they start with a tcp header and
// continues parsing further encapsulations.
-func ParseTCP(b []byte) (Layers, error) {
+func parseTCP(b []byte) (Layer, layerParser) {
h := header.TCP(b)
tcp := TCP{
SrcPort: Uint16(h.SourcePort()),
@@ -419,12 +530,7 @@ func ParseTCP(b []byte) (Layers, error) {
Checksum: Uint16(h.Checksum()),
UrgentPointer: Uint16(h.UrgentPointer()),
}
- layers := Layers{&tcp}
- moreLayers, err := ParsePayload(b[tcp.length():])
- if err != nil {
- return nil, err
- }
- return append(layers, moreLayers...), nil
+ return &tcp, parsePayload
}
func (l *TCP) match(other Layer) bool {
@@ -440,8 +546,86 @@ func (l *TCP) length() int {
// merge overrides the values in l with the values from other but only in fields
// where the value is not nil.
-func (l *TCP) merge(other TCP) error {
- return mergo.Merge(l, other, mergo.WithOverride)
+func (l *TCP) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// UDP can construct and match a UDP encapsulation.
+type UDP struct {
+ LayerBase
+ SrcPort *uint16
+ DstPort *uint16
+ Length *uint16
+ Checksum *uint16
+}
+
+func (l *UDP) String() string {
+ return stringLayer(l)
+}
+
+func (l *UDP) toBytes() ([]byte, error) {
+ b := make([]byte, header.UDPMinimumSize)
+ h := header.UDP(b)
+ if l.SrcPort != nil {
+ h.SetSourcePort(*l.SrcPort)
+ }
+ if l.DstPort != nil {
+ h.SetDestinationPort(*l.DstPort)
+ }
+ if l.Length != nil {
+ h.SetLength(*l.Length)
+ } else {
+ h.SetLength(uint16(totalLength(l)))
+ }
+ if l.Checksum != nil {
+ h.SetChecksum(*l.Checksum)
+ return h, nil
+ }
+ if err := setUDPChecksum(&h, l); err != nil {
+ return nil, err
+ }
+ return h, nil
+}
+
+// setUDPChecksum calculates the checksum of the UDP header and sets it in h.
+func setUDPChecksum(h *header.UDP, udp *UDP) error {
+ h.SetChecksum(0)
+ xsum, err := layerChecksum(udp, header.UDPProtocolNumber)
+ if err != nil {
+ return err
+ }
+ h.SetChecksum(^h.CalculateChecksum(xsum))
+ return nil
+}
+
+// parseUDP parses the bytes assuming that they start with a udp header and
+// returns the parsed layer and the next parser to use.
+func parseUDP(b []byte) (Layer, layerParser) {
+ h := header.UDP(b)
+ udp := UDP{
+ SrcPort: Uint16(h.SourcePort()),
+ DstPort: Uint16(h.DestinationPort()),
+ Length: Uint16(h.Length()),
+ Checksum: Uint16(h.Checksum()),
+ }
+ return &udp, parsePayload
+}
+
+func (l *UDP) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *UDP) length() int {
+ if l.Length == nil {
+ return header.UDPMinimumSize
+ }
+ return int(*l.Length)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *UDP) merge(other Layer) error {
+ return mergeLayer(l, other)
}
// Payload has bytes beyond OSI layer 4.
@@ -450,13 +634,17 @@ type Payload struct {
Bytes []byte
}
-// ParsePayload parses the bytes assuming that they start with a payload and
+func (l *Payload) String() string {
+ return stringLayer(l)
+}
+
+// parsePayload parses the bytes assuming that they start with a payload and
// continue to the end. There can be no further encapsulations.
-func ParsePayload(b []byte) (Layers, error) {
+func parsePayload(b []byte) (Layer, layerParser) {
payload := Payload{
Bytes: b,
}
- return Layers{&payload}, nil
+ return &payload, nil
}
func (l *Payload) toBytes() ([]byte, error) {
@@ -471,18 +659,33 @@ func (l *Payload) length() int {
return len(l.Bytes)
}
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *Payload) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
// Layers is an array of Layer and supports similar functions to Layer.
type Layers []Layer
-func (ls *Layers) toBytes() ([]byte, error) {
+// linkLayers sets the linked-list ponters in ls.
+func (ls *Layers) linkLayers() {
for i, l := range *ls {
if i > 0 {
l.setPrev((*ls)[i-1])
+ } else {
+ l.setPrev(nil)
}
if i+1 < len(*ls) {
l.setNext((*ls)[i+1])
+ } else {
+ l.setNext(nil)
}
}
+}
+
+func (ls *Layers) toBytes() ([]byte, error) {
+ ls.linkLayers()
outBytes := []byte{}
for _, l := range *ls {
layerBytes, err := l.toBytes()
@@ -498,8 +701,8 @@ func (ls *Layers) match(other Layers) bool {
if len(*ls) > len(other) {
return false
}
- for i := 0; i < len(*ls); i++ {
- if !equalLayer((*ls)[i], other[i]) {
+ for i, l := range *ls {
+ if !equalLayer(l, other[i]) {
return false
}
}
diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go
new file mode 100644
index 000000000..b32efda93
--- /dev/null
+++ b/test/packetimpact/testbench/layers_test.go
@@ -0,0 +1,156 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+func TestLayerMatch(t *testing.T) {
+ var nilPayload *Payload
+ noPayload := &Payload{}
+ emptyPayload := &Payload{Bytes: []byte{}}
+ fullPayload := &Payload{Bytes: []byte{1, 2, 3}}
+ emptyTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: emptyPayload}}
+ fullTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: fullPayload}}
+ for _, tt := range []struct {
+ a, b Layer
+ want bool
+ }{
+ {nilPayload, nilPayload, true},
+ {nilPayload, noPayload, true},
+ {nilPayload, emptyPayload, true},
+ {nilPayload, fullPayload, true},
+ {noPayload, noPayload, true},
+ {noPayload, emptyPayload, true},
+ {noPayload, fullPayload, true},
+ {emptyPayload, emptyPayload, true},
+ {emptyPayload, fullPayload, false},
+ {fullPayload, fullPayload, true},
+ {emptyTCP, fullTCP, true},
+ } {
+ if got := tt.a.match(tt.b); got != tt.want {
+ t.Errorf("%s.match(%s) = %t, want %t", tt.a, tt.b, got, tt.want)
+ }
+ if got := tt.b.match(tt.a); got != tt.want {
+ t.Errorf("%s.match(%s) = %t, want %t", tt.b, tt.a, got, tt.want)
+ }
+ }
+}
+
+func TestLayerStringFormat(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ l Layer
+ want string
+ }{
+ {
+ name: "TCP",
+ l: &TCP{
+ SrcPort: Uint16(34785),
+ DstPort: Uint16(47767),
+ SeqNum: Uint32(3452155723),
+ AckNum: Uint32(2596996163),
+ DataOffset: Uint8(5),
+ Flags: Uint8(20),
+ WindowSize: Uint16(64240),
+ Checksum: Uint16(0x2e2b),
+ },
+ want: "&testbench.TCP{" +
+ "SrcPort:34785 " +
+ "DstPort:47767 " +
+ "SeqNum:3452155723 " +
+ "AckNum:2596996163 " +
+ "DataOffset:5 " +
+ "Flags:20 " +
+ "WindowSize:64240 " +
+ "Checksum:11819" +
+ "}",
+ },
+ {
+ name: "UDP",
+ l: &UDP{
+ SrcPort: Uint16(34785),
+ DstPort: Uint16(47767),
+ Length: Uint16(12),
+ },
+ want: "&testbench.UDP{" +
+ "SrcPort:34785 " +
+ "DstPort:47767 " +
+ "Length:12" +
+ "}",
+ },
+ {
+ name: "IPv4",
+ l: &IPv4{
+ IHL: Uint8(5),
+ TOS: Uint8(0),
+ TotalLength: Uint16(44),
+ ID: Uint16(0),
+ Flags: Uint8(2),
+ FragmentOffset: Uint16(0),
+ TTL: Uint8(64),
+ Protocol: Uint8(6),
+ Checksum: Uint16(0x2e2b),
+ SrcAddr: Address(tcpip.Address([]byte{197, 34, 63, 10})),
+ DstAddr: Address(tcpip.Address([]byte{197, 34, 63, 20})),
+ },
+ want: "&testbench.IPv4{" +
+ "IHL:5 " +
+ "TOS:0 " +
+ "TotalLength:44 " +
+ "ID:0 " +
+ "Flags:2 " +
+ "FragmentOffset:0 " +
+ "TTL:64 " +
+ "Protocol:6 " +
+ "Checksum:11819 " +
+ "SrcAddr:197.34.63.10 " +
+ "DstAddr:197.34.63.20" +
+ "}",
+ },
+ {
+ name: "Ether",
+ l: &Ether{
+ SrcAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x0a})),
+ DstAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x14})),
+ Type: NetworkProtocolNumber(4),
+ },
+ want: "&testbench.Ether{" +
+ "SrcAddr:02:42:c5:22:3f:0a " +
+ "DstAddr:02:42:c5:22:3f:14 " +
+ "Type:4" +
+ "}",
+ },
+ {
+ name: "Payload",
+ l: &Payload{
+ Bytes: []byte("Hooray for packetimpact."),
+ },
+ want: "&testbench.Payload{Bytes:\n" +
+ "00000000 48 6f 6f 72 61 79 20 66 6f 72 20 70 61 63 6b 65 |Hooray for packe|\n" +
+ "00000010 74 69 6d 70 61 63 74 2e |timpact.|\n" +
+ "}",
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := tt.l.String(); got != tt.want {
+ t.Errorf("%s.String() = %s, want: %s", tt.name, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go
index 0c7d0f979..ff722d4a6 100644
--- a/test/packetimpact/testbench/rawsockets.go
+++ b/test/packetimpact/testbench/rawsockets.go
@@ -17,6 +17,7 @@ package testbench
import (
"encoding/binary"
"flag"
+ "fmt"
"math"
"net"
"testing"
@@ -47,6 +48,12 @@ func NewSniffer(t *testing.T) (Sniffer, error) {
if err != nil {
return Sniffer{}, err
}
+ if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, 1); err != nil {
+ t.Fatalf("can't set sockopt SO_RCVBUFFORCE to 1: %s", err)
+ }
+ if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1e7); err != nil {
+ t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err)
+ }
return Sniffer{
t: t,
fd: snifferFd,
@@ -91,12 +98,36 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte {
}
}
-// Close the socket that Sniffer is using.
-func (s *Sniffer) Close() {
+// Drain drains the Sniffer's socket receive buffer by receiving until there's
+// nothing else to receive.
+func (s *Sniffer) Drain() {
+ s.t.Helper()
+ flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0)
+ if err != nil {
+ s.t.Fatalf("failed to get sniffer socket fd flags: %s", err)
+ }
+ if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags|unix.O_NONBLOCK); err != nil {
+ s.t.Fatalf("failed to make sniffer socket non-blocking: %s", err)
+ }
+ for {
+ buf := make([]byte, maxReadSize)
+ _, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC)
+ if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK {
+ break
+ }
+ }
+ if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil {
+ s.t.Fatalf("failed to restore sniffer socket fd flags: %s", err)
+ }
+}
+
+// close the socket that Sniffer is using.
+func (s *Sniffer) close() error {
if err := unix.Close(s.fd); err != nil {
- s.t.Fatalf("can't close sniffer socket: %s", err)
+ return fmt.Errorf("can't close sniffer socket: %w", err)
}
s.fd = -1
+ return nil
}
// Injector can inject raw frames.
@@ -142,10 +173,11 @@ func (i *Injector) Send(b []byte) {
}
}
-// Close the underlying socket.
-func (i *Injector) Close() {
+// close the underlying socket.
+func (i *Injector) close() error {
if err := unix.Close(i.fd); err != nil {
- i.t.Fatalf("can't close sniffer socket: %s", err)
+ return fmt.Errorf("can't close sniffer socket: %w", err)
}
i.fd = -1
+ return nil
}