diff options
Diffstat (limited to 'test/packetimpact/testbench')
-rw-r--r-- | test/packetimpact/testbench/connections.go | 72 | ||||
-rw-r--r-- | test/packetimpact/testbench/dut.go | 46 | ||||
-rw-r--r-- | test/packetimpact/testbench/layers.go | 5 |
3 files changed, 112 insertions, 11 deletions
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index bf104e5ca..6e85d6fab 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -266,14 +266,14 @@ func SeqNumValue(v seqnum.Value) *seqnum.Value { } // newTCPState creates a new TCPState. -func newTCPState(domain int, out, in TCP) (*tcpState, error) { +func newTCPState(domain int, out, in TCP) (*tcpState, unix.Sockaddr, error) { portPickerFD, localAddr, err := pickPort(domain, unix.SOCK_STREAM) if err != nil { - return nil, err + return nil, nil, err } localPort, err := portFromSockaddr(localAddr) if err != nil { - return nil, err + return nil, nil, err } s := tcpState{ out: TCP{SrcPort: &localPort}, @@ -283,12 +283,12 @@ func newTCPState(domain int, out, in TCP) (*tcpState, error) { finSent: false, } if err := s.out.merge(&out); err != nil { - return nil, err + return nil, nil, err } if err := s.in.merge(&in); err != nil { - return nil, err + return nil, nil, err } - return &s, nil + return &s, localAddr, nil } func (s *tcpState) outgoing() Layer { @@ -606,7 +606,7 @@ func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { if err != nil { t.Fatalf("can't make ipv4State: %s", err) } - tcpState, err := newTCPState(unix.AF_INET, outgoingTCP, incomingTCP) + tcpState, localAddr, err := newTCPState(unix.AF_INET, outgoingTCP, incomingTCP) if err != nil { t.Fatalf("can't make tcpState: %s", err) } @@ -623,19 +623,41 @@ func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { layerStates: []layerState{etherState, ipv4State, tcpState}, injector: injector, sniffer: sniffer, + localAddr: localAddr, t: t, } } -// Handshake performs a TCP 3-way handshake. The input Connection should have a +// Connect performs a TCP 3-way handshake. The input Connection should have a // final TCP Layer. -func (conn *TCPIPv4) Handshake() { +func (conn *TCPIPv4) Connect() { + conn.t.Helper() + // Send the SYN. conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)}) // Wait for the SYN-ACK. synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) - if synAck == nil { + if err != nil { + conn.t.Fatalf("didn't get synack during handshake: %s", err) + } + conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck + + // Send an ACK. + conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) +} + +// ConnectWithOptions performs a TCP 3-way handshake with given TCP options. +// The input Connection should have a final TCP Layer. +func (conn *TCPIPv4) ConnectWithOptions(options []byte) { + conn.t.Helper() + + // Send the SYN. + conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn), Options: options}) + + // Wait for the SYN-ACK. + synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + if err != nil { conn.t.Fatalf("didn't get synack during handshake: %s", err) } conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck @@ -655,6 +677,31 @@ func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duratio return (*Connection)(conn).ExpectFrame(expected, timeout) } +// ExpectNextData attempts to receive the next incoming segment for the +// connection and expects that to match the given layers. +// +// It differs from ExpectData() in that here we are only interested in the next +// received segment, while ExpectData() can receive multiple segments for the +// connection until there is a match with given layers or a timeout. +func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + // Receive the first incoming TCP segment for this connection. + got, err := conn.ExpectData(&TCP{}, nil, timeout) + if err != nil { + return nil, err + } + + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = tcp + if payload != nil { + expected = append(expected, payload) + tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum()) - uint32(payload.Length())) + } + if !(*Connection)(conn).match(expected, got) { + return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got) + } + return got, nil +} + // Send a packet with reasonable defaults. Potentially override the TCP layer in // the connection with the provided layer and add additionLayers. func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { @@ -703,6 +750,11 @@ func (conn *TCPIPv4) SynAck() *TCP { return conn.state().synAck } +// LocalAddr gets the local socket address of this connection. +func (conn *TCPIPv4) LocalAddr() unix.Sockaddr { + return conn.localAddr +} + // IPv6Conn maintains the state for all the layers in a IPv6 connection. type IPv6Conn Connection diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index b919a3c2e..2a2afecb5 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -241,7 +241,9 @@ func (dut *DUT) Connect(fd int32, sa unix.Sockaddr) { ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() ret, err := dut.ConnectWithErrno(ctx, fd, sa) - if ret != 0 { + // Ignore 'operation in progress' error that can be returned when the socket + // is non-blocking. + if err != syscall.Errno(unix.EINPROGRESS) && ret != 0 { dut.t.Fatalf("failed to connect socket: %s", err) } } @@ -260,6 +262,35 @@ func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } +// Fcntl calls fcntl on the DUT and causes a fatal test failure if it +// doesn't succeed. If more control over the timeout or error handling is +// needed, use FcntlWithErrno. +func (dut *DUT) Fcntl(fd, cmd, arg int32) int32 { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.FcntlWithErrno(ctx, fd, cmd, arg) + if ret == -1 { + dut.t.Fatalf("failed to Fcntl: ret=%d, errno=%s", ret, err) + } + return ret +} + +// FcntlWithErrno calls fcntl on the DUT. +func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32, error) { + dut.t.Helper() + req := pb.FcntlRequest{ + Fd: fd, + Cmd: cmd, + Arg: arg, + } + resp, err := dut.posixServer.Fcntl(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Fcntl: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + // GetSockName calls getsockname on the DUT and causes a fatal test failure if // it doesn't succeed. If more control over the timeout or error handling is // needed, use GetSockNameWithErrno. @@ -476,6 +507,19 @@ func (dut *DUT) SendToWithErrno(ctx context.Context, sockfd int32, buf []byte, f return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } +// SetNonBlocking will set O_NONBLOCK flag for fd if nonblocking +// is true, otherwise it will clear the flag. +func (dut *DUT) SetNonBlocking(fd int32, nonblocking bool) { + dut.t.Helper() + flags := dut.Fcntl(fd, unix.F_GETFL, 0) + if nonblocking { + flags |= unix.O_NONBLOCK + } else { + flags &= ^unix.O_NONBLOCK + } + dut.Fcntl(fd, unix.F_SETFL, flags) +} + func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) { dut.t.Helper() req := pb.SetSockOptRequest{ diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index 1b0e5b8fc..560c4111b 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -939,6 +939,11 @@ func (l *Payload) ToBytes() ([]byte, error) { return l.Bytes, nil } +// Length returns payload byte length. +func (l *Payload) Length() int { + return l.length() +} + func (l *Payload) match(other Layer) bool { return equalLayer(l, other) } |