summaryrefslogtreecommitdiffhomepage
path: root/test/packetimpact/testbench
diff options
context:
space:
mode:
Diffstat (limited to 'test/packetimpact/testbench')
-rw-r--r--test/packetimpact/testbench/connections.go72
-rw-r--r--test/packetimpact/testbench/dut.go46
-rw-r--r--test/packetimpact/testbench/layers.go5
3 files changed, 112 insertions, 11 deletions
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index bf104e5ca..6e85d6fab 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -266,14 +266,14 @@ func SeqNumValue(v seqnum.Value) *seqnum.Value {
}
// newTCPState creates a new TCPState.
-func newTCPState(domain int, out, in TCP) (*tcpState, error) {
+func newTCPState(domain int, out, in TCP) (*tcpState, unix.Sockaddr, error) {
portPickerFD, localAddr, err := pickPort(domain, unix.SOCK_STREAM)
if err != nil {
- return nil, err
+ return nil, nil, err
}
localPort, err := portFromSockaddr(localAddr)
if err != nil {
- return nil, err
+ return nil, nil, err
}
s := tcpState{
out: TCP{SrcPort: &localPort},
@@ -283,12 +283,12 @@ func newTCPState(domain int, out, in TCP) (*tcpState, error) {
finSent: false,
}
if err := s.out.merge(&out); err != nil {
- return nil, err
+ return nil, nil, err
}
if err := s.in.merge(&in); err != nil {
- return nil, err
+ return nil, nil, err
}
- return &s, nil
+ return &s, localAddr, nil
}
func (s *tcpState) outgoing() Layer {
@@ -606,7 +606,7 @@ func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
if err != nil {
t.Fatalf("can't make ipv4State: %s", err)
}
- tcpState, err := newTCPState(unix.AF_INET, outgoingTCP, incomingTCP)
+ tcpState, localAddr, err := newTCPState(unix.AF_INET, outgoingTCP, incomingTCP)
if err != nil {
t.Fatalf("can't make tcpState: %s", err)
}
@@ -623,19 +623,41 @@ func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
layerStates: []layerState{etherState, ipv4State, tcpState},
injector: injector,
sniffer: sniffer,
+ localAddr: localAddr,
t: t,
}
}
-// Handshake performs a TCP 3-way handshake. The input Connection should have a
+// Connect performs a TCP 3-way handshake. The input Connection should have a
// final TCP Layer.
-func (conn *TCPIPv4) Handshake() {
+func (conn *TCPIPv4) Connect() {
+ conn.t.Helper()
+
// Send the SYN.
conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)})
// Wait for the SYN-ACK.
synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
- if synAck == nil {
+ if err != nil {
+ conn.t.Fatalf("didn't get synack during handshake: %s", err)
+ }
+ conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
+
+ // Send an ACK.
+ conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)})
+}
+
+// ConnectWithOptions performs a TCP 3-way handshake with given TCP options.
+// The input Connection should have a final TCP Layer.
+func (conn *TCPIPv4) ConnectWithOptions(options []byte) {
+ conn.t.Helper()
+
+ // Send the SYN.
+ conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn), Options: options})
+
+ // Wait for the SYN-ACK.
+ synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if err != nil {
conn.t.Fatalf("didn't get synack during handshake: %s", err)
}
conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
@@ -655,6 +677,31 @@ func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duratio
return (*Connection)(conn).ExpectFrame(expected, timeout)
}
+// ExpectNextData attempts to receive the next incoming segment for the
+// connection and expects that to match the given layers.
+//
+// It differs from ExpectData() in that here we are only interested in the next
+// received segment, while ExpectData() can receive multiple segments for the
+// connection until there is a match with given layers or a timeout.
+func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ // Receive the first incoming TCP segment for this connection.
+ got, err := conn.ExpectData(&TCP{}, nil, timeout)
+ if err != nil {
+ return nil, err
+ }
+
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = tcp
+ if payload != nil {
+ expected = append(expected, payload)
+ tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum()) - uint32(payload.Length()))
+ }
+ if !(*Connection)(conn).match(expected, got) {
+ return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got)
+ }
+ return got, nil
+}
+
// Send a packet with reasonable defaults. Potentially override the TCP layer in
// the connection with the provided layer and add additionLayers.
func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
@@ -703,6 +750,11 @@ func (conn *TCPIPv4) SynAck() *TCP {
return conn.state().synAck
}
+// LocalAddr gets the local socket address of this connection.
+func (conn *TCPIPv4) LocalAddr() unix.Sockaddr {
+ return conn.localAddr
+}
+
// IPv6Conn maintains the state for all the layers in a IPv6 connection.
type IPv6Conn Connection
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
index b919a3c2e..2a2afecb5 100644
--- a/test/packetimpact/testbench/dut.go
+++ b/test/packetimpact/testbench/dut.go
@@ -241,7 +241,9 @@ func (dut *DUT) Connect(fd int32, sa unix.Sockaddr) {
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
ret, err := dut.ConnectWithErrno(ctx, fd, sa)
- if ret != 0 {
+ // Ignore 'operation in progress' error that can be returned when the socket
+ // is non-blocking.
+ if err != syscall.Errno(unix.EINPROGRESS) && ret != 0 {
dut.t.Fatalf("failed to connect socket: %s", err)
}
}
@@ -260,6 +262,35 @@ func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
+// Fcntl calls fcntl on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use FcntlWithErrno.
+func (dut *DUT) Fcntl(fd, cmd, arg int32) int32 {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.FcntlWithErrno(ctx, fd, cmd, arg)
+ if ret == -1 {
+ dut.t.Fatalf("failed to Fcntl: ret=%d, errno=%s", ret, err)
+ }
+ return ret
+}
+
+// FcntlWithErrno calls fcntl on the DUT.
+func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.FcntlRequest{
+ Fd: fd,
+ Cmd: cmd,
+ Arg: arg,
+ }
+ resp, err := dut.posixServer.Fcntl(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Fcntl: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
// GetSockName calls getsockname on the DUT and causes a fatal test failure if
// it doesn't succeed. If more control over the timeout or error handling is
// needed, use GetSockNameWithErrno.
@@ -476,6 +507,19 @@ func (dut *DUT) SendToWithErrno(ctx context.Context, sockfd int32, buf []byte, f
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
+// SetNonBlocking will set O_NONBLOCK flag for fd if nonblocking
+// is true, otherwise it will clear the flag.
+func (dut *DUT) SetNonBlocking(fd int32, nonblocking bool) {
+ dut.t.Helper()
+ flags := dut.Fcntl(fd, unix.F_GETFL, 0)
+ if nonblocking {
+ flags |= unix.O_NONBLOCK
+ } else {
+ flags &= ^unix.O_NONBLOCK
+ }
+ dut.Fcntl(fd, unix.F_SETFL, flags)
+}
+
func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) {
dut.t.Helper()
req := pb.SetSockOptRequest{
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index 1b0e5b8fc..560c4111b 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -939,6 +939,11 @@ func (l *Payload) ToBytes() ([]byte, error) {
return l.Bytes, nil
}
+// Length returns payload byte length.
+func (l *Payload) Length() int {
+ return l.length()
+}
+
func (l *Payload) match(other Layer) bool {
return equalLayer(l, other)
}