summaryrefslogtreecommitdiffhomepage
path: root/test/packetimpact/testbench/connections.go
diff options
context:
space:
mode:
Diffstat (limited to 'test/packetimpact/testbench/connections.go')
-rw-r--r--test/packetimpact/testbench/connections.go773
1 files changed, 499 insertions, 274 deletions
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index 579da59c3..f84fd8ba7 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -21,10 +21,12 @@ import (
"fmt"
"math/rand"
"net"
+ "strings"
"testing"
"time"
"github.com/mohae/deepcopy"
+ "go.uber.org/multierr"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -62,384 +64,607 @@ func pickPort() (int, uint16, error) {
return fd, uint16(newSockAddrInet4.Port), nil
}
-// TCPIPv4 maintains state about a TCP/IPv4 connection.
-type TCPIPv4 struct {
- outgoing Layers
- incoming Layers
- LocalSeqNum seqnum.Value
- RemoteSeqNum seqnum.Value
- SynAck *TCP
- sniffer Sniffer
- injector Injector
- portPickerFD int
- t *testing.T
+// layerState stores the state of a layer of a connection.
+type layerState interface {
+ // outgoing returns an outgoing layer to be sent in a frame.
+ outgoing() Layer
+
+ // incoming creates an expected Layer for comparing against a received Layer.
+ // Because the expectation can depend on values in the received Layer, it is
+ // an input to incoming. For example, the ACK number needs to be checked in a
+ // TCP packet but only if the ACK flag is set in the received packet.
+ incoming(received Layer) Layer
+
+ // sent updates the layerState based on the Layer that was sent. The input is
+ // a Layer with all prev and next pointers populated so that the entire frame
+ // as it was sent is available.
+ sent(sent Layer) error
+
+ // received updates the layerState based on a Layer that is receieved. The
+ // input is a Layer with all prev and next pointers populated so that the
+ // entire frame as it was receieved is available.
+ received(received Layer) error
+
+ // close frees associated resources held by the LayerState.
+ close() error
}
-// tcpLayerIndex is the position of the TCP layer in the TCPIPv4 connection. It
-// is the third, after Ethernet and IPv4.
-const tcpLayerIndex int = 2
+// etherState maintains state about an Ethernet connection.
+type etherState struct {
+ out, in Ether
+}
-// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
-func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
+var _ layerState = (*etherState)(nil)
+
+// newEtherState creates a new etherState.
+func newEtherState(out, in Ether) (*etherState, error) {
lMAC, err := tcpip.ParseMACAddress(*localMAC)
if err != nil {
- t.Fatalf("can't parse localMAC %q: %s", *localMAC, err)
+ return nil, err
}
rMAC, err := tcpip.ParseMACAddress(*remoteMAC)
if err != nil {
- t.Fatalf("can't parse remoteMAC %q: %s", *remoteMAC, err)
+ return nil, err
}
-
- portPickerFD, localPort, err := pickPort()
- if err != nil {
- t.Fatalf("can't pick a port: %s", err)
+ s := etherState{
+ out: Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
+ in: Ether{SrcAddr: &rMAC, DstAddr: &lMAC},
}
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+func (s *etherState) outgoing() Layer {
+ return &s.out
+}
+
+func (s *etherState) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*etherState) sent(Layer) error {
+ return nil
+}
+
+func (*etherState) received(Layer) error {
+ return nil
+}
+
+func (*etherState) close() error {
+ return nil
+}
+
+// ipv4State maintains state about an IPv4 connection.
+type ipv4State struct {
+ out, in IPv4
+}
+
+var _ layerState = (*ipv4State)(nil)
+
+// newIPv4State creates a new ipv4State.
+func newIPv4State(out, in IPv4) (*ipv4State, error) {
lIP := tcpip.Address(net.ParseIP(*localIPv4).To4())
rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4())
-
- sniffer, err := NewSniffer(t)
- if err != nil {
- t.Fatalf("can't make new sniffer: %s", err)
+ s := ipv4State{
+ out: IPv4{SrcAddr: &lIP, DstAddr: &rIP},
+ in: IPv4{SrcAddr: &rIP, DstAddr: &lIP},
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
}
+ return &s, nil
+}
- injector, err := NewInjector(t)
+func (s *ipv4State) outgoing() Layer {
+ return &s.out
+}
+
+func (s *ipv4State) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*ipv4State) sent(Layer) error {
+ return nil
+}
+
+func (*ipv4State) received(Layer) error {
+ return nil
+}
+
+func (*ipv4State) close() error {
+ return nil
+}
+
+// tcpState maintains state about a TCP connection.
+type tcpState struct {
+ out, in TCP
+ localSeqNum, remoteSeqNum *seqnum.Value
+ synAck *TCP
+ portPickerFD int
+ finSent bool
+}
+
+var _ layerState = (*tcpState)(nil)
+
+// SeqNumValue is a helper routine that allocates a new seqnum.Value value to
+// store v and returns a pointer to it.
+func SeqNumValue(v seqnum.Value) *seqnum.Value {
+ return &v
+}
+
+// newTCPState creates a new TCPState.
+func newTCPState(out, in TCP) (*tcpState, error) {
+ portPickerFD, localPort, err := pickPort()
if err != nil {
- t.Fatalf("can't make new injector: %s", err)
+ return nil, err
+ }
+ s := tcpState{
+ out: TCP{SrcPort: &localPort},
+ in: TCP{DstPort: &localPort},
+ localSeqNum: SeqNumValue(seqnum.Value(rand.Uint32())),
+ portPickerFD: portPickerFD,
+ finSent: false,
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
}
+ return &s, nil
+}
- newOutgoingTCP := &TCP{
- SrcPort: &localPort,
+func (s *tcpState) outgoing() Layer {
+ newOutgoing := deepcopy.Copy(s.out).(TCP)
+ if s.localSeqNum != nil {
+ newOutgoing.SeqNum = Uint32(uint32(*s.localSeqNum))
}
- if err := newOutgoingTCP.merge(outgoingTCP); err != nil {
- t.Fatalf("can't merge %+v into %+v: %s", outgoingTCP, newOutgoingTCP, err)
+ if s.remoteSeqNum != nil {
+ newOutgoing.AckNum = Uint32(uint32(*s.remoteSeqNum))
}
- newIncomingTCP := &TCP{
- DstPort: &localPort,
+ return &newOutgoing
+}
+
+func (s *tcpState) incoming(received Layer) Layer {
+ tcpReceived, ok := received.(*TCP)
+ if !ok {
+ return nil
}
- if err := newIncomingTCP.merge(incomingTCP); err != nil {
- t.Fatalf("can't merge %+v into %+v: %s", incomingTCP, newIncomingTCP, err)
+ newIn := deepcopy.Copy(s.in).(TCP)
+ if s.remoteSeqNum != nil {
+ newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum))
}
- return TCPIPv4{
- outgoing: Layers{
- &Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
- &IPv4{SrcAddr: &lIP, DstAddr: &rIP},
- newOutgoingTCP},
- incoming: Layers{
- &Ether{SrcAddr: &rMAC, DstAddr: &lMAC},
- &IPv4{SrcAddr: &rIP, DstAddr: &lIP},
- newIncomingTCP},
- sniffer: sniffer,
- injector: injector,
- portPickerFD: portPickerFD,
- t: t,
- LocalSeqNum: seqnum.Value(rand.Uint32()),
+ if s.localSeqNum != nil && (*tcpReceived.Flags&header.TCPFlagAck) != 0 {
+ // The caller didn't specify an AckNum so we'll expect the calculated one,
+ // but only if the ACK flag is set because the AckNum is not valid in a
+ // header if ACK is not set.
+ newIn.AckNum = Uint32(uint32(*s.localSeqNum))
}
+ return &newIn
}
-// Close the injector and sniffer associated with this connection.
-func (conn *TCPIPv4) Close() {
- conn.sniffer.Close()
- conn.injector.Close()
- if err := unix.Close(conn.portPickerFD); err != nil {
- conn.t.Fatalf("can't close portPickerFD: %s", err)
+func (s *tcpState) sent(sent Layer) error {
+ tcp, ok := sent.(*TCP)
+ if !ok {
+ return fmt.Errorf("can't update tcpState with %T Layer", sent)
}
- conn.portPickerFD = -1
+ if !s.finSent {
+ // update localSeqNum by the payload only when FIN is not yet sent by us
+ for current := tcp.next(); current != nil; current = current.next() {
+ s.localSeqNum.UpdateForward(seqnum.Size(current.length()))
+ }
+ }
+ if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
+ s.localSeqNum.UpdateForward(1)
+ }
+ if *tcp.Flags&(header.TCPFlagFin) != 0 {
+ s.finSent = true
+ }
+ return nil
}
-// CreateFrame builds a frame for the connection with tcp overriding defaults
-// and additionalLayers added after the TCP header.
-func (conn *TCPIPv4) CreateFrame(tcp TCP, additionalLayers ...Layer) Layers {
- if tcp.SeqNum == nil {
- tcp.SeqNum = Uint32(uint32(conn.LocalSeqNum))
+func (s *tcpState) received(l Layer) error {
+ tcp, ok := l.(*TCP)
+ if !ok {
+ return fmt.Errorf("can't update tcpState with %T Layer", l)
}
- if tcp.AckNum == nil {
- tcp.AckNum = Uint32(uint32(conn.RemoteSeqNum))
+ s.remoteSeqNum = SeqNumValue(seqnum.Value(*tcp.SeqNum))
+ if *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
+ s.remoteSeqNum.UpdateForward(1)
}
- layersToSend := deepcopy.Copy(conn.outgoing).(Layers)
- if err := layersToSend[tcpLayerIndex].(*TCP).merge(tcp); err != nil {
- conn.t.Fatalf("can't merge %+v into %+v: %s", tcp, layersToSend[tcpLayerIndex], err)
+ for current := tcp.next(); current != nil; current = current.next() {
+ s.remoteSeqNum.UpdateForward(seqnum.Size(current.length()))
}
- layersToSend = append(layersToSend, additionalLayers...)
- return layersToSend
+ return nil
}
-// SendFrame sends a frame with reasonable defaults.
-func (conn *TCPIPv4) SendFrame(frame Layers) {
- outBytes, err := frame.toBytes()
- if err != nil {
- conn.t.Fatalf("can't build outgoing TCP packet: %s", err)
+// close frees the port associated with this connection.
+func (s *tcpState) close() error {
+ if err := unix.Close(s.portPickerFD); err != nil {
+ return err
}
- conn.injector.Send(outBytes)
+ s.portPickerFD = -1
+ return nil
+}
- // Compute the next TCP sequence number.
- for i := tcpLayerIndex + 1; i < len(frame); i++ {
- conn.LocalSeqNum.UpdateForward(seqnum.Size(frame[i].length()))
+// udpState maintains state about a UDP connection.
+type udpState struct {
+ out, in UDP
+ portPickerFD int
+}
+
+var _ layerState = (*udpState)(nil)
+
+// newUDPState creates a new udpState.
+func newUDPState(out, in UDP) (*udpState, error) {
+ portPickerFD, localPort, err := pickPort()
+ if err != nil {
+ return nil, err
}
- tcp := frame[tcpLayerIndex].(*TCP)
- if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
- conn.LocalSeqNum.UpdateForward(1)
+ s := udpState{
+ out: UDP{SrcPort: &localPort},
+ in: UDP{DstPort: &localPort},
+ portPickerFD: portPickerFD,
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
}
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
}
-// Send a packet with reasonable defaults and override some fields by tcp.
-func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
- conn.SendFrame(conn.CreateFrame(tcp, additionalLayers...))
+func (s *udpState) outgoing() Layer {
+ return &s.out
+}
+
+func (s *udpState) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*udpState) sent(l Layer) error {
+ return nil
+}
+
+func (*udpState) received(l Layer) error {
+ return nil
}
-// Recv gets a packet from the sniffer within the timeout provided.
-// If no packet arrives before the timeout, it returns nil.
-func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP {
- layers := conn.RecvFrame(timeout)
- if tcpLayerIndex < len(layers) {
- return layers[tcpLayerIndex].(*TCP)
+// close frees the port associated with this connection.
+func (s *udpState) close() error {
+ if err := unix.Close(s.portPickerFD); err != nil {
+ return err
}
+ s.portPickerFD = -1
return nil
}
-// RecvFrame gets a frame (of type Layers) within the timeout provided.
-// If no frame arrives before the timeout, it returns nil.
-func (conn *TCPIPv4) RecvFrame(timeout time.Duration) Layers {
- deadline := time.Now().Add(timeout)
- for {
- timeout = time.Until(deadline)
- if timeout <= 0 {
- break
- }
- b := conn.sniffer.Recv(timeout)
- if b == nil {
- break
- }
- layers, err := ParseEther(b)
- if err != nil {
- conn.t.Logf("can't parse frame: %s", err)
- continue // Ignore packets that can't be parsed.
- }
- if !conn.incoming.match(layers) {
- continue // Ignore packets that don't match the expected incoming.
+// Connection holds a collection of layer states for maintaining a connection
+// along with sockets for sniffer and injecting packets.
+type Connection struct {
+ layerStates []layerState
+ injector Injector
+ sniffer Sniffer
+ t *testing.T
+}
+
+// match tries to match each Layer in received against the incoming filter. If
+// received is longer than layerStates then that may still count as a match. The
+// reverse is never a match. override overrides the default matchers for each
+// Layer.
+func (conn *Connection) match(override, received Layers) bool {
+ if len(received) < len(conn.layerStates) {
+ return false
+ }
+ for i, s := range conn.layerStates {
+ toMatch := s.incoming(received[i])
+ if toMatch == nil {
+ return false
}
- tcpHeader := (layers[tcpLayerIndex]).(*TCP)
- conn.RemoteSeqNum = seqnum.Value(*tcpHeader.SeqNum)
- if *tcpHeader.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
- conn.RemoteSeqNum.UpdateForward(1)
+ if i < len(override) {
+ toMatch.merge(override[i])
}
- for i := tcpLayerIndex + 1; i < len(layers); i++ {
- conn.RemoteSeqNum.UpdateForward(seqnum.Size(layers[i].length()))
+ if !toMatch.match(received[i]) {
+ return false
}
- return layers
}
- return nil
+ return true
}
-// Expect a packet that matches the provided tcp within the timeout specified.
-// If it doesn't arrive in time, it returns nil.
-func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) *TCP {
- // We cannot implement this directly using ExpectFrame as we cannot specify
- // the Payload part.
- deadline := time.Now().Add(timeout)
- for {
- timeout = time.Until(deadline)
- if timeout <= 0 {
- return nil
+// Close frees associated resources held by the Connection.
+func (conn *Connection) Close() {
+ errs := multierr.Combine(conn.sniffer.close(), conn.injector.close())
+ for _, s := range conn.layerStates {
+ if err := s.close(); err != nil {
+ errs = multierr.Append(errs, fmt.Errorf("unable to close %+v: %s", s, err))
}
- gotTCP := conn.Recv(timeout)
- if tcp.match(gotTCP) {
- return gotTCP
+ }
+ if errs != nil {
+ conn.t.Fatalf("unable to close %+v: %s", conn, errs)
+ }
+}
+
+// CreateFrame builds a frame for the connection with layer overriding defaults
+// of the innermost layer and additionalLayers added after it.
+func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
+ var layersToSend Layers
+ for _, s := range conn.layerStates {
+ layersToSend = append(layersToSend, s.outgoing())
+ }
+ if err := layersToSend[len(layersToSend)-1].merge(layer); err != nil {
+ conn.t.Fatalf("can't merge %+v into %+v: %s", layer, layersToSend[len(layersToSend)-1], err)
+ }
+ layersToSend = append(layersToSend, additionalLayers...)
+ return layersToSend
+}
+
+// SendFrame sends a frame on the wire and updates the state of all layers.
+func (conn *Connection) SendFrame(frame Layers) {
+ outBytes, err := frame.toBytes()
+ if err != nil {
+ conn.t.Fatalf("can't build outgoing TCP packet: %s", err)
+ }
+ conn.injector.Send(outBytes)
+
+ // frame might have nil values where the caller wanted to use default values.
+ // sentFrame will have no nil values in it because it comes from parsing the
+ // bytes that were actually sent.
+ sentFrame := parse(parseEther, outBytes)
+ // Update the state of each layer based on what was sent.
+ for i, s := range conn.layerStates {
+ if err := s.sent(sentFrame[i]); err != nil {
+ conn.t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err)
}
}
}
-// ExpectFrame expects a frame that matches the specified layers within the
+// Send a packet with reasonable defaults. Potentially override the final layer
+// in the connection with the provided layer and add additionLayers.
+func (conn *Connection) Send(layer Layer, additionalLayers ...Layer) {
+ conn.SendFrame(conn.CreateFrame(layer, additionalLayers...))
+}
+
+// recvFrame gets the next successfully parsed frame (of type Layers) within the
+// timeout provided. If no parsable frame arrives before the timeout, it returns
+// nil.
+func (conn *Connection) recvFrame(timeout time.Duration) Layers {
+ if timeout <= 0 {
+ return nil
+ }
+ b := conn.sniffer.Recv(timeout)
+ if b == nil {
+ return nil
+ }
+ return parse(parseEther, b)
+}
+
+// Expect a frame with the final layerStates layer matching the provided Layer
+// within the timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) {
+ // Make a frame that will ignore all but the final layer.
+ layers := make([]Layer, len(conn.layerStates))
+ layers[len(layers)-1] = layer
+
+ gotFrame, err := conn.ExpectFrame(layers, timeout)
+ if err != nil {
+ return nil, err
+ }
+ if len(conn.layerStates)-1 < len(gotFrame) {
+ return gotFrame[len(conn.layerStates)-1], nil
+ }
+ conn.t.Fatal("the received frame should be at least as long as the expected layers")
+ return nil, fmt.Errorf("the received frame should be at least as long as the expected layers")
+}
+
+// ExpectFrame expects a frame that matches the provided Layers within the
// timeout specified. If it doesn't arrive in time, it returns nil.
-func (conn *TCPIPv4) ExpectFrame(layers Layers, timeout time.Duration) Layers {
+func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) {
deadline := time.Now().Add(timeout)
+ var allLayers []string
for {
- timeout = time.Until(deadline)
- if timeout <= 0 {
- return nil
+ var gotLayers Layers
+ if timeout = time.Until(deadline); timeout > 0 {
+ gotLayers = conn.recvFrame(timeout)
}
- gotLayers := conn.RecvFrame(timeout)
- if layers.match(gotLayers) {
- return gotLayers
+ if gotLayers == nil {
+ return nil, fmt.Errorf("got %d packets:\n%s", len(allLayers), strings.Join(allLayers, "\n"))
}
+ if conn.match(layers, gotLayers) {
+ for i, s := range conn.layerStates {
+ if err := s.received(gotLayers[i]); err != nil {
+ conn.t.Fatal(err)
+ }
+ }
+ return gotLayers, nil
+ }
+ allLayers = append(allLayers, fmt.Sprintf("%s", gotLayers))
}
}
-// ExpectData is a convenient method that expects a TCP packet along with
-// the payload to arrive within the timeout specified. If it doesn't arrive
-// in time, it causes a fatal test failure.
-func (conn *TCPIPv4) ExpectData(tcp TCP, data []byte, timeout time.Duration) {
- expected := []Layer{&Ether{}, &IPv4{}, &tcp}
- if len(data) > 0 {
- expected = append(expected, &Payload{Bytes: data})
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *Connection) Drain() {
+ conn.sniffer.Drain()
+}
+
+// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection.
+type TCPIPv4 Connection
+
+// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
+func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make etherState: %s", err)
+ }
+ ipv4State, err := newIPv4State(IPv4{}, IPv4{})
+ if err != nil {
+ t.Fatalf("can't make ipv4State: %s", err)
}
- if conn.ExpectFrame(expected, timeout) == nil {
- conn.t.Fatalf("expected to get a TCP frame %s with payload %x", &tcp, data)
+ tcpState, err := newTCPState(outgoingTCP, incomingTCP)
+ if err != nil {
+ t.Fatalf("can't make tcpState: %s", err)
+ }
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return TCPIPv4{
+ layerStates: []layerState{etherState, ipv4State, tcpState},
+ injector: injector,
+ sniffer: sniffer,
+ t: t,
}
}
-// Handshake performs a TCP 3-way handshake.
+// Handshake performs a TCP 3-way handshake. The input Connection should have a
+// final TCP Layer.
func (conn *TCPIPv4) Handshake() {
// Send the SYN.
conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)})
// Wait for the SYN-ACK.
- conn.SynAck = conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
- if conn.SynAck == nil {
- conn.t.Fatalf("didn't get synack during handshake")
+ synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if synAck == nil {
+ conn.t.Fatalf("didn't get synack during handshake: %s", err)
}
+ conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
// Send an ACK.
conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)})
}
-// UDPIPv4 maintains state about a UDP/IPv4 connection.
-type UDPIPv4 struct {
- outgoing Layers
- incoming Layers
- sniffer Sniffer
- injector Injector
- portPickerFD int
- t *testing.T
+// ExpectData is a convenient method that expects a Layer and the Layer after
+// it. If it doens't arrive in time, it returns nil.
+func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = tcp
+ if payload != nil {
+ expected = append(expected, payload)
+ }
+ return (*Connection)(conn).ExpectFrame(expected, timeout)
+}
+
+// Send a packet with reasonable defaults. Potentially override the TCP layer in
+// the connection with the provided layer and add additionLayers.
+func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
+ (*Connection)(conn).Send(&tcp, additionalLayers...)
+}
+
+// Close frees associated resources held by the TCPIPv4 connection.
+func (conn *TCPIPv4) Close() {
+ (*Connection)(conn).Close()
}
-// udpLayerIndex is the position of the UDP layer in the UDPIPv4 connection. It
-// is the third, after Ethernet and IPv4.
-const udpLayerIndex int = 2
+// Expect a frame with the TCP layer matching the provided TCP within the
+// timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) {
+ layer, err := (*Connection)(conn).Expect(&tcp, timeout)
+ if layer == nil {
+ return nil, err
+ }
+ gotTCP, ok := layer.(*TCP)
+ if !ok {
+ conn.t.Fatalf("expected %s to be TCP", layer)
+ }
+ return gotTCP, err
+}
+
+func (conn *TCPIPv4) state() *tcpState {
+ state, ok := conn.layerStates[len(conn.layerStates)-1].(*tcpState)
+ if !ok {
+ conn.t.Fatalf("expected final state of %v to be tcpState", conn.layerStates)
+ }
+ return state
+}
+
+// RemoteSeqNum returns the next expected sequence number from the DUT.
+func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value {
+ return conn.state().remoteSeqNum
+}
+
+// LocalSeqNum returns the next sequence number to send from the testbench.
+func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value {
+ return conn.state().localSeqNum
+}
+
+// SynAck returns the SynAck that was part of the handshake.
+func (conn *TCPIPv4) SynAck() *TCP {
+ return conn.state().synAck
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *TCPIPv4) Drain() {
+ conn.sniffer.Drain()
+}
+
+// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection.
+type UDPIPv4 Connection
// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults.
func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
- lMAC, err := tcpip.ParseMACAddress(*localMAC)
+ etherState, err := newEtherState(Ether{}, Ether{})
if err != nil {
- t.Fatalf("can't parse localMAC %q: %s", *localMAC, err)
+ t.Fatalf("can't make etherState: %s", err)
}
-
- rMAC, err := tcpip.ParseMACAddress(*remoteMAC)
+ ipv4State, err := newIPv4State(IPv4{}, IPv4{})
if err != nil {
- t.Fatalf("can't parse remoteMAC %q: %s", *remoteMAC, err)
+ t.Fatalf("can't make ipv4State: %s", err)
}
-
- portPickerFD, localPort, err := pickPort()
+ tcpState, err := newUDPState(outgoingUDP, incomingUDP)
if err != nil {
- t.Fatalf("can't pick a port: %s", err)
+ t.Fatalf("can't make udpState: %s", err)
}
- lIP := tcpip.Address(net.ParseIP(*localIPv4).To4())
- rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4())
-
- sniffer, err := NewSniffer(t)
+ injector, err := NewInjector(t)
if err != nil {
- t.Fatalf("can't make new sniffer: %s", err)
+ t.Fatalf("can't make injector: %s", err)
}
-
- injector, err := NewInjector(t)
+ sniffer, err := NewSniffer(t)
if err != nil {
- t.Fatalf("can't make new injector: %s", err)
+ t.Fatalf("can't make sniffer: %s", err)
}
- newOutgoingUDP := &UDP{
- SrcPort: &localPort,
- }
- if err := newOutgoingUDP.merge(outgoingUDP); err != nil {
- t.Fatalf("can't merge %+v into %+v: %s", outgoingUDP, newOutgoingUDP, err)
- }
- newIncomingUDP := &UDP{
- DstPort: &localPort,
- }
- if err := newIncomingUDP.merge(incomingUDP); err != nil {
- t.Fatalf("can't merge %+v into %+v: %s", incomingUDP, newIncomingUDP, err)
- }
return UDPIPv4{
- outgoing: Layers{
- &Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
- &IPv4{SrcAddr: &lIP, DstAddr: &rIP},
- newOutgoingUDP},
- incoming: Layers{
- &Ether{SrcAddr: &rMAC, DstAddr: &lMAC},
- &IPv4{SrcAddr: &rIP, DstAddr: &lIP},
- newIncomingUDP},
- sniffer: sniffer,
- injector: injector,
- portPickerFD: portPickerFD,
- t: t,
+ layerStates: []layerState{etherState, ipv4State, tcpState},
+ injector: injector,
+ sniffer: sniffer,
+ t: t,
}
}
-// Close the injector and sniffer associated with this connection.
-func (conn *UDPIPv4) Close() {
- conn.sniffer.Close()
- conn.injector.Close()
- if err := unix.Close(conn.portPickerFD); err != nil {
- conn.t.Fatalf("can't close portPickerFD: %s", err)
- }
- conn.portPickerFD = -1
-}
-
-// CreateFrame builds a frame for the connection with the provided udp
-// overriding defaults and the additionalLayers added after the UDP header.
-func (conn *UDPIPv4) CreateFrame(udp UDP, additionalLayers ...Layer) Layers {
- layersToSend := deepcopy.Copy(conn.outgoing).(Layers)
- if err := layersToSend[udpLayerIndex].(*UDP).merge(udp); err != nil {
- conn.t.Fatalf("can't merge %+v into %+v: %s", udp, layersToSend[udpLayerIndex], err)
- }
- layersToSend = append(layersToSend, additionalLayers...)
- return layersToSend
+// CreateFrame builds a frame for the connection with layer overriding defaults
+// of the innermost layer and additionalLayers added after it.
+func (conn *UDPIPv4) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
+ return (*Connection)(conn).CreateFrame(layer, additionalLayers...)
}
-// SendFrame sends a frame with reasonable defaults.
+// SendFrame sends a frame on the wire and updates the state of all layers.
func (conn *UDPIPv4) SendFrame(frame Layers) {
- outBytes, err := frame.toBytes()
- if err != nil {
- conn.t.Fatalf("can't build outgoing UDP packet: %s", err)
- }
- conn.injector.Send(outBytes)
+ (*Connection)(conn).SendFrame(frame)
}
-// Send a packet with reasonable defaults and override some fields by udp.
-func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) {
- conn.SendFrame(conn.CreateFrame(udp, additionalLayers...))
-}
-
-// Recv gets a packet from the sniffer within the timeout provided. If no packet
-// arrives before the timeout, it returns nil.
-func (conn *UDPIPv4) Recv(timeout time.Duration) *UDP {
- deadline := time.Now().Add(timeout)
- for {
- timeout = time.Until(deadline)
- if timeout <= 0 {
- break
- }
- b := conn.sniffer.Recv(timeout)
- if b == nil {
- break
- }
- layers, err := ParseEther(b)
- if err != nil {
- conn.t.Logf("can't parse frame: %s", err)
- continue // Ignore packets that can't be parsed.
- }
- if !conn.incoming.match(layers) {
- continue // Ignore packets that don't match the expected incoming.
- }
- return (layers[udpLayerIndex]).(*UDP)
- }
- return nil
+// Close frees associated resources held by the UDPIPv4 connection.
+func (conn *UDPIPv4) Close() {
+ (*Connection)(conn).Close()
}
-// Expect a packet that matches the provided udp within the timeout specified.
-// If it doesn't arrive in time, the test fails.
-func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) *UDP {
- deadline := time.Now().Add(timeout)
- for {
- timeout = time.Until(deadline)
- if timeout <= 0 {
- return nil
- }
- gotUDP := conn.Recv(timeout)
- if gotUDP == nil {
- return nil
- }
- if udp.match(gotUDP) {
- return gotUDP
- }
- }
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *UDPIPv4) Drain() {
+ conn.sniffer.Drain()
}