diff options
Diffstat (limited to 'test/packetimpact/testbench/connections.go')
-rw-r--r-- | test/packetimpact/testbench/connections.go | 72 |
1 files changed, 62 insertions, 10 deletions
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index bf104e5ca..6e85d6fab 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -266,14 +266,14 @@ func SeqNumValue(v seqnum.Value) *seqnum.Value { } // newTCPState creates a new TCPState. -func newTCPState(domain int, out, in TCP) (*tcpState, error) { +func newTCPState(domain int, out, in TCP) (*tcpState, unix.Sockaddr, error) { portPickerFD, localAddr, err := pickPort(domain, unix.SOCK_STREAM) if err != nil { - return nil, err + return nil, nil, err } localPort, err := portFromSockaddr(localAddr) if err != nil { - return nil, err + return nil, nil, err } s := tcpState{ out: TCP{SrcPort: &localPort}, @@ -283,12 +283,12 @@ func newTCPState(domain int, out, in TCP) (*tcpState, error) { finSent: false, } if err := s.out.merge(&out); err != nil { - return nil, err + return nil, nil, err } if err := s.in.merge(&in); err != nil { - return nil, err + return nil, nil, err } - return &s, nil + return &s, localAddr, nil } func (s *tcpState) outgoing() Layer { @@ -606,7 +606,7 @@ func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { if err != nil { t.Fatalf("can't make ipv4State: %s", err) } - tcpState, err := newTCPState(unix.AF_INET, outgoingTCP, incomingTCP) + tcpState, localAddr, err := newTCPState(unix.AF_INET, outgoingTCP, incomingTCP) if err != nil { t.Fatalf("can't make tcpState: %s", err) } @@ -623,19 +623,41 @@ func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { layerStates: []layerState{etherState, ipv4State, tcpState}, injector: injector, sniffer: sniffer, + localAddr: localAddr, t: t, } } -// Handshake performs a TCP 3-way handshake. The input Connection should have a +// Connect performs a TCP 3-way handshake. The input Connection should have a // final TCP Layer. -func (conn *TCPIPv4) Handshake() { +func (conn *TCPIPv4) Connect() { + conn.t.Helper() + // Send the SYN. conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)}) // Wait for the SYN-ACK. synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) - if synAck == nil { + if err != nil { + conn.t.Fatalf("didn't get synack during handshake: %s", err) + } + conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck + + // Send an ACK. + conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) +} + +// ConnectWithOptions performs a TCP 3-way handshake with given TCP options. +// The input Connection should have a final TCP Layer. +func (conn *TCPIPv4) ConnectWithOptions(options []byte) { + conn.t.Helper() + + // Send the SYN. + conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn), Options: options}) + + // Wait for the SYN-ACK. + synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + if err != nil { conn.t.Fatalf("didn't get synack during handshake: %s", err) } conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck @@ -655,6 +677,31 @@ func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duratio return (*Connection)(conn).ExpectFrame(expected, timeout) } +// ExpectNextData attempts to receive the next incoming segment for the +// connection and expects that to match the given layers. +// +// It differs from ExpectData() in that here we are only interested in the next +// received segment, while ExpectData() can receive multiple segments for the +// connection until there is a match with given layers or a timeout. +func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + // Receive the first incoming TCP segment for this connection. + got, err := conn.ExpectData(&TCP{}, nil, timeout) + if err != nil { + return nil, err + } + + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = tcp + if payload != nil { + expected = append(expected, payload) + tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum()) - uint32(payload.Length())) + } + if !(*Connection)(conn).match(expected, got) { + return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got) + } + return got, nil +} + // Send a packet with reasonable defaults. Potentially override the TCP layer in // the connection with the provided layer and add additionLayers. func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { @@ -703,6 +750,11 @@ func (conn *TCPIPv4) SynAck() *TCP { return conn.state().synAck } +// LocalAddr gets the local socket address of this connection. +func (conn *TCPIPv4) LocalAddr() unix.Sockaddr { + return conn.localAddr +} + // IPv6Conn maintains the state for all the layers in a IPv6 connection. type IPv6Conn Connection |