diff options
Diffstat (limited to 'test/packetimpact')
-rw-r--r-- | test/packetimpact/testbench/BUILD | 2 | ||||
-rw-r--r-- | test/packetimpact/testbench/connections.go | 773 | ||||
-rw-r--r-- | test/packetimpact/testbench/layers.go | 160 | ||||
-rw-r--r-- | test/packetimpact/testbench/layers_test.go | 109 | ||||
-rw-r--r-- | test/packetimpact/testbench/rawsockets.go | 38 | ||||
-rw-r--r-- | test/packetimpact/tests/BUILD | 48 | ||||
-rw-r--r-- | test/packetimpact/tests/fin_wait2_timeout_test.go | 14 | ||||
-rw-r--r-- | test/packetimpact/tests/tcp_close_wait_ack_test.go | 102 | ||||
-rw-r--r-- | test/packetimpact/tests/tcp_noaccept_close_rst_test.go | 37 | ||||
-rw-r--r-- | test/packetimpact/tests/tcp_outside_the_window_test.go | 88 | ||||
-rw-r--r-- | test/packetimpact/tests/tcp_should_piggyback_test.go | 59 | ||||
-rw-r--r-- | test/packetimpact/tests/tcp_window_shrink_test.go | 18 | ||||
-rw-r--r-- | test/packetimpact/tests/udp_recv_multicast_test.go | 2 |
13 files changed, 1104 insertions, 346 deletions
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD index 199823419..b6a254882 100644 --- a/test/packetimpact/testbench/BUILD +++ b/test/packetimpact/testbench/BUILD @@ -28,6 +28,7 @@ go_library( "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//keepalive:go_default_library", "@org_golang_x_sys//unix:go_default_library", + "@org_uber_go_multierr//:go_default_library", ], ) @@ -36,4 +37,5 @@ go_test( size = "small", srcs = ["layers_test.go"], library = ":testbench", + deps = ["//pkg/tcpip"], ) diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index 579da59c3..f84fd8ba7 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -21,10 +21,12 @@ import ( "fmt" "math/rand" "net" + "strings" "testing" "time" "github.com/mohae/deepcopy" + "go.uber.org/multierr" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -62,384 +64,607 @@ func pickPort() (int, uint16, error) { return fd, uint16(newSockAddrInet4.Port), nil } -// TCPIPv4 maintains state about a TCP/IPv4 connection. -type TCPIPv4 struct { - outgoing Layers - incoming Layers - LocalSeqNum seqnum.Value - RemoteSeqNum seqnum.Value - SynAck *TCP - sniffer Sniffer - injector Injector - portPickerFD int - t *testing.T +// layerState stores the state of a layer of a connection. +type layerState interface { + // outgoing returns an outgoing layer to be sent in a frame. + outgoing() Layer + + // incoming creates an expected Layer for comparing against a received Layer. + // Because the expectation can depend on values in the received Layer, it is + // an input to incoming. For example, the ACK number needs to be checked in a + // TCP packet but only if the ACK flag is set in the received packet. + incoming(received Layer) Layer + + // sent updates the layerState based on the Layer that was sent. The input is + // a Layer with all prev and next pointers populated so that the entire frame + // as it was sent is available. + sent(sent Layer) error + + // received updates the layerState based on a Layer that is receieved. The + // input is a Layer with all prev and next pointers populated so that the + // entire frame as it was receieved is available. + received(received Layer) error + + // close frees associated resources held by the LayerState. + close() error } -// tcpLayerIndex is the position of the TCP layer in the TCPIPv4 connection. It -// is the third, after Ethernet and IPv4. -const tcpLayerIndex int = 2 +// etherState maintains state about an Ethernet connection. +type etherState struct { + out, in Ether +} -// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults. -func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { +var _ layerState = (*etherState)(nil) + +// newEtherState creates a new etherState. +func newEtherState(out, in Ether) (*etherState, error) { lMAC, err := tcpip.ParseMACAddress(*localMAC) if err != nil { - t.Fatalf("can't parse localMAC %q: %s", *localMAC, err) + return nil, err } rMAC, err := tcpip.ParseMACAddress(*remoteMAC) if err != nil { - t.Fatalf("can't parse remoteMAC %q: %s", *remoteMAC, err) + return nil, err } - - portPickerFD, localPort, err := pickPort() - if err != nil { - t.Fatalf("can't pick a port: %s", err) + s := etherState{ + out: Ether{SrcAddr: &lMAC, DstAddr: &rMAC}, + in: Ether{SrcAddr: &rMAC, DstAddr: &lMAC}, } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +func (s *etherState) outgoing() Layer { + return &s.out +} + +func (s *etherState) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*etherState) sent(Layer) error { + return nil +} + +func (*etherState) received(Layer) error { + return nil +} + +func (*etherState) close() error { + return nil +} + +// ipv4State maintains state about an IPv4 connection. +type ipv4State struct { + out, in IPv4 +} + +var _ layerState = (*ipv4State)(nil) + +// newIPv4State creates a new ipv4State. +func newIPv4State(out, in IPv4) (*ipv4State, error) { lIP := tcpip.Address(net.ParseIP(*localIPv4).To4()) rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4()) - - sniffer, err := NewSniffer(t) - if err != nil { - t.Fatalf("can't make new sniffer: %s", err) + s := ipv4State{ + out: IPv4{SrcAddr: &lIP, DstAddr: &rIP}, + in: IPv4{SrcAddr: &rIP, DstAddr: &lIP}, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err } + return &s, nil +} - injector, err := NewInjector(t) +func (s *ipv4State) outgoing() Layer { + return &s.out +} + +func (s *ipv4State) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*ipv4State) sent(Layer) error { + return nil +} + +func (*ipv4State) received(Layer) error { + return nil +} + +func (*ipv4State) close() error { + return nil +} + +// tcpState maintains state about a TCP connection. +type tcpState struct { + out, in TCP + localSeqNum, remoteSeqNum *seqnum.Value + synAck *TCP + portPickerFD int + finSent bool +} + +var _ layerState = (*tcpState)(nil) + +// SeqNumValue is a helper routine that allocates a new seqnum.Value value to +// store v and returns a pointer to it. +func SeqNumValue(v seqnum.Value) *seqnum.Value { + return &v +} + +// newTCPState creates a new TCPState. +func newTCPState(out, in TCP) (*tcpState, error) { + portPickerFD, localPort, err := pickPort() if err != nil { - t.Fatalf("can't make new injector: %s", err) + return nil, err + } + s := tcpState{ + out: TCP{SrcPort: &localPort}, + in: TCP{DstPort: &localPort}, + localSeqNum: SeqNumValue(seqnum.Value(rand.Uint32())), + portPickerFD: portPickerFD, + finSent: false, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err } + return &s, nil +} - newOutgoingTCP := &TCP{ - SrcPort: &localPort, +func (s *tcpState) outgoing() Layer { + newOutgoing := deepcopy.Copy(s.out).(TCP) + if s.localSeqNum != nil { + newOutgoing.SeqNum = Uint32(uint32(*s.localSeqNum)) } - if err := newOutgoingTCP.merge(outgoingTCP); err != nil { - t.Fatalf("can't merge %+v into %+v: %s", outgoingTCP, newOutgoingTCP, err) + if s.remoteSeqNum != nil { + newOutgoing.AckNum = Uint32(uint32(*s.remoteSeqNum)) } - newIncomingTCP := &TCP{ - DstPort: &localPort, + return &newOutgoing +} + +func (s *tcpState) incoming(received Layer) Layer { + tcpReceived, ok := received.(*TCP) + if !ok { + return nil } - if err := newIncomingTCP.merge(incomingTCP); err != nil { - t.Fatalf("can't merge %+v into %+v: %s", incomingTCP, newIncomingTCP, err) + newIn := deepcopy.Copy(s.in).(TCP) + if s.remoteSeqNum != nil { + newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum)) } - return TCPIPv4{ - outgoing: Layers{ - &Ether{SrcAddr: &lMAC, DstAddr: &rMAC}, - &IPv4{SrcAddr: &lIP, DstAddr: &rIP}, - newOutgoingTCP}, - incoming: Layers{ - &Ether{SrcAddr: &rMAC, DstAddr: &lMAC}, - &IPv4{SrcAddr: &rIP, DstAddr: &lIP}, - newIncomingTCP}, - sniffer: sniffer, - injector: injector, - portPickerFD: portPickerFD, - t: t, - LocalSeqNum: seqnum.Value(rand.Uint32()), + if s.localSeqNum != nil && (*tcpReceived.Flags&header.TCPFlagAck) != 0 { + // The caller didn't specify an AckNum so we'll expect the calculated one, + // but only if the ACK flag is set because the AckNum is not valid in a + // header if ACK is not set. + newIn.AckNum = Uint32(uint32(*s.localSeqNum)) } + return &newIn } -// Close the injector and sniffer associated with this connection. -func (conn *TCPIPv4) Close() { - conn.sniffer.Close() - conn.injector.Close() - if err := unix.Close(conn.portPickerFD); err != nil { - conn.t.Fatalf("can't close portPickerFD: %s", err) +func (s *tcpState) sent(sent Layer) error { + tcp, ok := sent.(*TCP) + if !ok { + return fmt.Errorf("can't update tcpState with %T Layer", sent) } - conn.portPickerFD = -1 + if !s.finSent { + // update localSeqNum by the payload only when FIN is not yet sent by us + for current := tcp.next(); current != nil; current = current.next() { + s.localSeqNum.UpdateForward(seqnum.Size(current.length())) + } + } + if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { + s.localSeqNum.UpdateForward(1) + } + if *tcp.Flags&(header.TCPFlagFin) != 0 { + s.finSent = true + } + return nil } -// CreateFrame builds a frame for the connection with tcp overriding defaults -// and additionalLayers added after the TCP header. -func (conn *TCPIPv4) CreateFrame(tcp TCP, additionalLayers ...Layer) Layers { - if tcp.SeqNum == nil { - tcp.SeqNum = Uint32(uint32(conn.LocalSeqNum)) +func (s *tcpState) received(l Layer) error { + tcp, ok := l.(*TCP) + if !ok { + return fmt.Errorf("can't update tcpState with %T Layer", l) } - if tcp.AckNum == nil { - tcp.AckNum = Uint32(uint32(conn.RemoteSeqNum)) + s.remoteSeqNum = SeqNumValue(seqnum.Value(*tcp.SeqNum)) + if *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { + s.remoteSeqNum.UpdateForward(1) } - layersToSend := deepcopy.Copy(conn.outgoing).(Layers) - if err := layersToSend[tcpLayerIndex].(*TCP).merge(tcp); err != nil { - conn.t.Fatalf("can't merge %+v into %+v: %s", tcp, layersToSend[tcpLayerIndex], err) + for current := tcp.next(); current != nil; current = current.next() { + s.remoteSeqNum.UpdateForward(seqnum.Size(current.length())) } - layersToSend = append(layersToSend, additionalLayers...) - return layersToSend + return nil } -// SendFrame sends a frame with reasonable defaults. -func (conn *TCPIPv4) SendFrame(frame Layers) { - outBytes, err := frame.toBytes() - if err != nil { - conn.t.Fatalf("can't build outgoing TCP packet: %s", err) +// close frees the port associated with this connection. +func (s *tcpState) close() error { + if err := unix.Close(s.portPickerFD); err != nil { + return err } - conn.injector.Send(outBytes) + s.portPickerFD = -1 + return nil +} - // Compute the next TCP sequence number. - for i := tcpLayerIndex + 1; i < len(frame); i++ { - conn.LocalSeqNum.UpdateForward(seqnum.Size(frame[i].length())) +// udpState maintains state about a UDP connection. +type udpState struct { + out, in UDP + portPickerFD int +} + +var _ layerState = (*udpState)(nil) + +// newUDPState creates a new udpState. +func newUDPState(out, in UDP) (*udpState, error) { + portPickerFD, localPort, err := pickPort() + if err != nil { + return nil, err } - tcp := frame[tcpLayerIndex].(*TCP) - if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { - conn.LocalSeqNum.UpdateForward(1) + s := udpState{ + out: UDP{SrcPort: &localPort}, + in: UDP{DstPort: &localPort}, + portPickerFD: portPickerFD, + } + if err := s.out.merge(&out); err != nil { + return nil, err } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil } -// Send a packet with reasonable defaults and override some fields by tcp. -func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { - conn.SendFrame(conn.CreateFrame(tcp, additionalLayers...)) +func (s *udpState) outgoing() Layer { + return &s.out +} + +func (s *udpState) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*udpState) sent(l Layer) error { + return nil +} + +func (*udpState) received(l Layer) error { + return nil } -// Recv gets a packet from the sniffer within the timeout provided. -// If no packet arrives before the timeout, it returns nil. -func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP { - layers := conn.RecvFrame(timeout) - if tcpLayerIndex < len(layers) { - return layers[tcpLayerIndex].(*TCP) +// close frees the port associated with this connection. +func (s *udpState) close() error { + if err := unix.Close(s.portPickerFD); err != nil { + return err } + s.portPickerFD = -1 return nil } -// RecvFrame gets a frame (of type Layers) within the timeout provided. -// If no frame arrives before the timeout, it returns nil. -func (conn *TCPIPv4) RecvFrame(timeout time.Duration) Layers { - deadline := time.Now().Add(timeout) - for { - timeout = time.Until(deadline) - if timeout <= 0 { - break - } - b := conn.sniffer.Recv(timeout) - if b == nil { - break - } - layers, err := ParseEther(b) - if err != nil { - conn.t.Logf("can't parse frame: %s", err) - continue // Ignore packets that can't be parsed. - } - if !conn.incoming.match(layers) { - continue // Ignore packets that don't match the expected incoming. +// Connection holds a collection of layer states for maintaining a connection +// along with sockets for sniffer and injecting packets. +type Connection struct { + layerStates []layerState + injector Injector + sniffer Sniffer + t *testing.T +} + +// match tries to match each Layer in received against the incoming filter. If +// received is longer than layerStates then that may still count as a match. The +// reverse is never a match. override overrides the default matchers for each +// Layer. +func (conn *Connection) match(override, received Layers) bool { + if len(received) < len(conn.layerStates) { + return false + } + for i, s := range conn.layerStates { + toMatch := s.incoming(received[i]) + if toMatch == nil { + return false } - tcpHeader := (layers[tcpLayerIndex]).(*TCP) - conn.RemoteSeqNum = seqnum.Value(*tcpHeader.SeqNum) - if *tcpHeader.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { - conn.RemoteSeqNum.UpdateForward(1) + if i < len(override) { + toMatch.merge(override[i]) } - for i := tcpLayerIndex + 1; i < len(layers); i++ { - conn.RemoteSeqNum.UpdateForward(seqnum.Size(layers[i].length())) + if !toMatch.match(received[i]) { + return false } - return layers } - return nil + return true } -// Expect a packet that matches the provided tcp within the timeout specified. -// If it doesn't arrive in time, it returns nil. -func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) *TCP { - // We cannot implement this directly using ExpectFrame as we cannot specify - // the Payload part. - deadline := time.Now().Add(timeout) - for { - timeout = time.Until(deadline) - if timeout <= 0 { - return nil +// Close frees associated resources held by the Connection. +func (conn *Connection) Close() { + errs := multierr.Combine(conn.sniffer.close(), conn.injector.close()) + for _, s := range conn.layerStates { + if err := s.close(); err != nil { + errs = multierr.Append(errs, fmt.Errorf("unable to close %+v: %s", s, err)) } - gotTCP := conn.Recv(timeout) - if tcp.match(gotTCP) { - return gotTCP + } + if errs != nil { + conn.t.Fatalf("unable to close %+v: %s", conn, errs) + } +} + +// CreateFrame builds a frame for the connection with layer overriding defaults +// of the innermost layer and additionalLayers added after it. +func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Layers { + var layersToSend Layers + for _, s := range conn.layerStates { + layersToSend = append(layersToSend, s.outgoing()) + } + if err := layersToSend[len(layersToSend)-1].merge(layer); err != nil { + conn.t.Fatalf("can't merge %+v into %+v: %s", layer, layersToSend[len(layersToSend)-1], err) + } + layersToSend = append(layersToSend, additionalLayers...) + return layersToSend +} + +// SendFrame sends a frame on the wire and updates the state of all layers. +func (conn *Connection) SendFrame(frame Layers) { + outBytes, err := frame.toBytes() + if err != nil { + conn.t.Fatalf("can't build outgoing TCP packet: %s", err) + } + conn.injector.Send(outBytes) + + // frame might have nil values where the caller wanted to use default values. + // sentFrame will have no nil values in it because it comes from parsing the + // bytes that were actually sent. + sentFrame := parse(parseEther, outBytes) + // Update the state of each layer based on what was sent. + for i, s := range conn.layerStates { + if err := s.sent(sentFrame[i]); err != nil { + conn.t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err) } } } -// ExpectFrame expects a frame that matches the specified layers within the +// Send a packet with reasonable defaults. Potentially override the final layer +// in the connection with the provided layer and add additionLayers. +func (conn *Connection) Send(layer Layer, additionalLayers ...Layer) { + conn.SendFrame(conn.CreateFrame(layer, additionalLayers...)) +} + +// recvFrame gets the next successfully parsed frame (of type Layers) within the +// timeout provided. If no parsable frame arrives before the timeout, it returns +// nil. +func (conn *Connection) recvFrame(timeout time.Duration) Layers { + if timeout <= 0 { + return nil + } + b := conn.sniffer.Recv(timeout) + if b == nil { + return nil + } + return parse(parseEther, b) +} + +// Expect a frame with the final layerStates layer matching the provided Layer +// within the timeout specified. If it doesn't arrive in time, it returns nil. +func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) { + // Make a frame that will ignore all but the final layer. + layers := make([]Layer, len(conn.layerStates)) + layers[len(layers)-1] = layer + + gotFrame, err := conn.ExpectFrame(layers, timeout) + if err != nil { + return nil, err + } + if len(conn.layerStates)-1 < len(gotFrame) { + return gotFrame[len(conn.layerStates)-1], nil + } + conn.t.Fatal("the received frame should be at least as long as the expected layers") + return nil, fmt.Errorf("the received frame should be at least as long as the expected layers") +} + +// ExpectFrame expects a frame that matches the provided Layers within the // timeout specified. If it doesn't arrive in time, it returns nil. -func (conn *TCPIPv4) ExpectFrame(layers Layers, timeout time.Duration) Layers { +func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) { deadline := time.Now().Add(timeout) + var allLayers []string for { - timeout = time.Until(deadline) - if timeout <= 0 { - return nil + var gotLayers Layers + if timeout = time.Until(deadline); timeout > 0 { + gotLayers = conn.recvFrame(timeout) } - gotLayers := conn.RecvFrame(timeout) - if layers.match(gotLayers) { - return gotLayers + if gotLayers == nil { + return nil, fmt.Errorf("got %d packets:\n%s", len(allLayers), strings.Join(allLayers, "\n")) } + if conn.match(layers, gotLayers) { + for i, s := range conn.layerStates { + if err := s.received(gotLayers[i]); err != nil { + conn.t.Fatal(err) + } + } + return gotLayers, nil + } + allLayers = append(allLayers, fmt.Sprintf("%s", gotLayers)) } } -// ExpectData is a convenient method that expects a TCP packet along with -// the payload to arrive within the timeout specified. If it doesn't arrive -// in time, it causes a fatal test failure. -func (conn *TCPIPv4) ExpectData(tcp TCP, data []byte, timeout time.Duration) { - expected := []Layer{&Ether{}, &IPv4{}, &tcp} - if len(data) > 0 { - expected = append(expected, &Payload{Bytes: data}) +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *Connection) Drain() { + conn.sniffer.Drain() +} + +// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection. +type TCPIPv4 Connection + +// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults. +func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv4State, err := newIPv4State(IPv4{}, IPv4{}) + if err != nil { + t.Fatalf("can't make ipv4State: %s", err) } - if conn.ExpectFrame(expected, timeout) == nil { - conn.t.Fatalf("expected to get a TCP frame %s with payload %x", &tcp, data) + tcpState, err := newTCPState(outgoingTCP, incomingTCP) + if err != nil { + t.Fatalf("can't make tcpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return TCPIPv4{ + layerStates: []layerState{etherState, ipv4State, tcpState}, + injector: injector, + sniffer: sniffer, + t: t, } } -// Handshake performs a TCP 3-way handshake. +// Handshake performs a TCP 3-way handshake. The input Connection should have a +// final TCP Layer. func (conn *TCPIPv4) Handshake() { // Send the SYN. conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)}) // Wait for the SYN-ACK. - conn.SynAck = conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) - if conn.SynAck == nil { - conn.t.Fatalf("didn't get synack during handshake") + synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + if synAck == nil { + conn.t.Fatalf("didn't get synack during handshake: %s", err) } + conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck // Send an ACK. conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) } -// UDPIPv4 maintains state about a UDP/IPv4 connection. -type UDPIPv4 struct { - outgoing Layers - incoming Layers - sniffer Sniffer - injector Injector - portPickerFD int - t *testing.T +// ExpectData is a convenient method that expects a Layer and the Layer after +// it. If it doens't arrive in time, it returns nil. +func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = tcp + if payload != nil { + expected = append(expected, payload) + } + return (*Connection)(conn).ExpectFrame(expected, timeout) +} + +// Send a packet with reasonable defaults. Potentially override the TCP layer in +// the connection with the provided layer and add additionLayers. +func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { + (*Connection)(conn).Send(&tcp, additionalLayers...) +} + +// Close frees associated resources held by the TCPIPv4 connection. +func (conn *TCPIPv4) Close() { + (*Connection)(conn).Close() } -// udpLayerIndex is the position of the UDP layer in the UDPIPv4 connection. It -// is the third, after Ethernet and IPv4. -const udpLayerIndex int = 2 +// Expect a frame with the TCP layer matching the provided TCP within the +// timeout specified. If it doesn't arrive in time, it returns nil. +func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) { + layer, err := (*Connection)(conn).Expect(&tcp, timeout) + if layer == nil { + return nil, err + } + gotTCP, ok := layer.(*TCP) + if !ok { + conn.t.Fatalf("expected %s to be TCP", layer) + } + return gotTCP, err +} + +func (conn *TCPIPv4) state() *tcpState { + state, ok := conn.layerStates[len(conn.layerStates)-1].(*tcpState) + if !ok { + conn.t.Fatalf("expected final state of %v to be tcpState", conn.layerStates) + } + return state +} + +// RemoteSeqNum returns the next expected sequence number from the DUT. +func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value { + return conn.state().remoteSeqNum +} + +// LocalSeqNum returns the next sequence number to send from the testbench. +func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value { + return conn.state().localSeqNum +} + +// SynAck returns the SynAck that was part of the handshake. +func (conn *TCPIPv4) SynAck() *TCP { + return conn.state().synAck +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *TCPIPv4) Drain() { + conn.sniffer.Drain() +} + +// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection. +type UDPIPv4 Connection // NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults. func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 { - lMAC, err := tcpip.ParseMACAddress(*localMAC) + etherState, err := newEtherState(Ether{}, Ether{}) if err != nil { - t.Fatalf("can't parse localMAC %q: %s", *localMAC, err) + t.Fatalf("can't make etherState: %s", err) } - - rMAC, err := tcpip.ParseMACAddress(*remoteMAC) + ipv4State, err := newIPv4State(IPv4{}, IPv4{}) if err != nil { - t.Fatalf("can't parse remoteMAC %q: %s", *remoteMAC, err) + t.Fatalf("can't make ipv4State: %s", err) } - - portPickerFD, localPort, err := pickPort() + tcpState, err := newUDPState(outgoingUDP, incomingUDP) if err != nil { - t.Fatalf("can't pick a port: %s", err) + t.Fatalf("can't make udpState: %s", err) } - lIP := tcpip.Address(net.ParseIP(*localIPv4).To4()) - rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4()) - - sniffer, err := NewSniffer(t) + injector, err := NewInjector(t) if err != nil { - t.Fatalf("can't make new sniffer: %s", err) + t.Fatalf("can't make injector: %s", err) } - - injector, err := NewInjector(t) + sniffer, err := NewSniffer(t) if err != nil { - t.Fatalf("can't make new injector: %s", err) + t.Fatalf("can't make sniffer: %s", err) } - newOutgoingUDP := &UDP{ - SrcPort: &localPort, - } - if err := newOutgoingUDP.merge(outgoingUDP); err != nil { - t.Fatalf("can't merge %+v into %+v: %s", outgoingUDP, newOutgoingUDP, err) - } - newIncomingUDP := &UDP{ - DstPort: &localPort, - } - if err := newIncomingUDP.merge(incomingUDP); err != nil { - t.Fatalf("can't merge %+v into %+v: %s", incomingUDP, newIncomingUDP, err) - } return UDPIPv4{ - outgoing: Layers{ - &Ether{SrcAddr: &lMAC, DstAddr: &rMAC}, - &IPv4{SrcAddr: &lIP, DstAddr: &rIP}, - newOutgoingUDP}, - incoming: Layers{ - &Ether{SrcAddr: &rMAC, DstAddr: &lMAC}, - &IPv4{SrcAddr: &rIP, DstAddr: &lIP}, - newIncomingUDP}, - sniffer: sniffer, - injector: injector, - portPickerFD: portPickerFD, - t: t, + layerStates: []layerState{etherState, ipv4State, tcpState}, + injector: injector, + sniffer: sniffer, + t: t, } } -// Close the injector and sniffer associated with this connection. -func (conn *UDPIPv4) Close() { - conn.sniffer.Close() - conn.injector.Close() - if err := unix.Close(conn.portPickerFD); err != nil { - conn.t.Fatalf("can't close portPickerFD: %s", err) - } - conn.portPickerFD = -1 -} - -// CreateFrame builds a frame for the connection with the provided udp -// overriding defaults and the additionalLayers added after the UDP header. -func (conn *UDPIPv4) CreateFrame(udp UDP, additionalLayers ...Layer) Layers { - layersToSend := deepcopy.Copy(conn.outgoing).(Layers) - if err := layersToSend[udpLayerIndex].(*UDP).merge(udp); err != nil { - conn.t.Fatalf("can't merge %+v into %+v: %s", udp, layersToSend[udpLayerIndex], err) - } - layersToSend = append(layersToSend, additionalLayers...) - return layersToSend +// CreateFrame builds a frame for the connection with layer overriding defaults +// of the innermost layer and additionalLayers added after it. +func (conn *UDPIPv4) CreateFrame(layer Layer, additionalLayers ...Layer) Layers { + return (*Connection)(conn).CreateFrame(layer, additionalLayers...) } -// SendFrame sends a frame with reasonable defaults. +// SendFrame sends a frame on the wire and updates the state of all layers. func (conn *UDPIPv4) SendFrame(frame Layers) { - outBytes, err := frame.toBytes() - if err != nil { - conn.t.Fatalf("can't build outgoing UDP packet: %s", err) - } - conn.injector.Send(outBytes) + (*Connection)(conn).SendFrame(frame) } -// Send a packet with reasonable defaults and override some fields by udp. -func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) { - conn.SendFrame(conn.CreateFrame(udp, additionalLayers...)) -} - -// Recv gets a packet from the sniffer within the timeout provided. If no packet -// arrives before the timeout, it returns nil. -func (conn *UDPIPv4) Recv(timeout time.Duration) *UDP { - deadline := time.Now().Add(timeout) - for { - timeout = time.Until(deadline) - if timeout <= 0 { - break - } - b := conn.sniffer.Recv(timeout) - if b == nil { - break - } - layers, err := ParseEther(b) - if err != nil { - conn.t.Logf("can't parse frame: %s", err) - continue // Ignore packets that can't be parsed. - } - if !conn.incoming.match(layers) { - continue // Ignore packets that don't match the expected incoming. - } - return (layers[udpLayerIndex]).(*UDP) - } - return nil +// Close frees associated resources held by the UDPIPv4 connection. +func (conn *UDPIPv4) Close() { + (*Connection)(conn).Close() } -// Expect a packet that matches the provided udp within the timeout specified. -// If it doesn't arrive in time, the test fails. -func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) *UDP { - deadline := time.Now().Add(timeout) - for { - timeout = time.Until(deadline) - if timeout <= 0 { - return nil - } - gotUDP := conn.Recv(timeout) - if gotUDP == nil { - return nil - } - if udp.match(gotUDP) { - return gotUDP - } - } +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *UDPIPv4) Drain() { + conn.sniffer.Drain() } diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index 4d6625941..5ce324f0d 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -15,6 +15,7 @@ package testbench import ( + "encoding/hex" "fmt" "reflect" "strings" @@ -64,6 +65,9 @@ type Layer interface { // setPrev sets the pointer to the Layer encapsulating this one. setPrev(Layer) + + // merge overrides the values in the interface with the provided values. + merge(Layer) error } // LayerBase is the common elements of all layers. @@ -91,6 +95,9 @@ func (lb *LayerBase) setPrev(l Layer) { // equalLayer compares that two Layer structs match while ignoring field in // which either input has a nil and also ignoring the LayerBase of the inputs. func equalLayer(x, y Layer) bool { + if x == nil || y == nil { + return true + } // opt ignores comparison pairs where either of the inputs is a nil. opt := cmp.FilterValues(func(x, y interface{}) bool { for _, l := range []interface{}{x, y} { @@ -104,6 +111,15 @@ func equalLayer(x, y Layer) bool { return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{})) } +// mergeLayer merges other in layer. Any non-nil value in other overrides the +// corresponding value in layer. If other is nil, no action is performed. +func mergeLayer(layer, other Layer) error { + if other == nil { + return nil + } + return mergo.Merge(layer, other, mergo.WithOverride) +} + func stringLayer(l Layer) string { v := reflect.ValueOf(l).Elem() t := v.Type() @@ -118,7 +134,12 @@ func stringLayer(l Layer) string { if v.IsNil() { continue } - ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v)) + v = reflect.Indirect(v) + if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 { + ret = append(ret, fmt.Sprintf("%s:\n%v", t.Name, hex.Dump(v.Bytes()))) + } else { + ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v)) + } } return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " ")) } @@ -153,7 +174,7 @@ func (l *Ether) toBytes() ([]byte, error) { fields.Type = header.IPv4ProtocolNumber default: // TODO(b/150301488): Support more protocols, like IPv6. - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", n) + return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n) } } h.Encode(fields) @@ -172,27 +193,46 @@ func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocol return &v } -// ParseEther parses the bytes assuming that they start with an ethernet header +// layerParser parses the input bytes and returns a Layer along with the next +// layerParser to run. If there is no more parsing to do, the returned +// layerParser is nil. +type layerParser func([]byte) (Layer, layerParser) + +// parse parses bytes starting with the first layerParser and using successive +// layerParsers until all the bytes are parsed. +func parse(parser layerParser, b []byte) Layers { + var layers Layers + for { + var layer Layer + layer, parser = parser(b) + layers = append(layers, layer) + if parser == nil { + break + } + b = b[layer.length():] + } + layers.linkLayers() + return layers +} + +// parseEther parses the bytes assuming that they start with an ethernet header // and continues parsing further encapsulations. -func ParseEther(b []byte) (Layers, error) { +func parseEther(b []byte) (Layer, layerParser) { h := header.Ethernet(b) ether := Ether{ SrcAddr: LinkAddress(h.SourceAddress()), DstAddr: LinkAddress(h.DestinationAddress()), Type: NetworkProtocolNumber(h.Type()), } - layers := Layers{ðer} + var nextParser layerParser switch h.Type() { case header.IPv4ProtocolNumber: - moreLayers, err := ParseIPv4(b[ether.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + nextParser = parseIPv4 default: - // TODO(b/150301488): Support more protocols, like IPv6. - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %#v", b) + // Assume that the rest is a payload. + nextParser = parsePayload } + return ðer, nextParser } func (l *Ether) match(other Layer) bool { @@ -203,6 +243,12 @@ func (l *Ether) length() int { return header.EthernetMinimumSize } +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *Ether) merge(other Layer) error { + return mergeLayer(l, other) +} + // IPv4 can construct and match an IPv4 encapsulation. type IPv4 struct { LayerBase @@ -274,7 +320,7 @@ func (l *IPv4) toBytes() ([]byte, error) { fields.Protocol = uint8(header.UDPProtocolNumber) default: // TODO(b/150301488): Support more protocols as needed. - return nil, fmt.Errorf("can't deduce the ip header's next protocol: %#v", n) + return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n) } } if l.SrcAddr != nil { @@ -311,9 +357,9 @@ func Address(v tcpip.Address) *tcpip.Address { return &v } -// ParseIPv4 parses the bytes assuming that they start with an ipv4 header and +// parseIPv4 parses the bytes assuming that they start with an ipv4 header and // continues parsing further encapsulations. -func ParseIPv4(b []byte) (Layers, error) { +func parseIPv4(b []byte) (Layer, layerParser) { h := header.IPv4(b) tos, _ := h.TOS() ipv4 := IPv4{ @@ -329,22 +375,17 @@ func ParseIPv4(b []byte) (Layers, error) { SrcAddr: Address(h.SourceAddress()), DstAddr: Address(h.DestinationAddress()), } - layers := Layers{&ipv4} + var nextParser layerParser switch h.TransportProtocol() { case header.TCPProtocolNumber: - moreLayers, err := ParseTCP(b[ipv4.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + nextParser = parseTCP case header.UDPProtocolNumber: - moreLayers, err := ParseUDP(b[ipv4.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + nextParser = parseUDP + default: + // Assume that the rest is a payload. + nextParser = parsePayload } - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", h.Protocol()) + return &ipv4, nextParser } func (l *IPv4) match(other Layer) bool { @@ -358,6 +399,12 @@ func (l *IPv4) length() int { return int(*l.IHL) } +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv4) merge(other Layer) error { + return mergeLayer(l, other) +} + // TCP can construct and match a TCP encapsulation. type TCP struct { LayerBase @@ -468,9 +515,9 @@ func Uint32(v uint32) *uint32 { return &v } -// ParseTCP parses the bytes assuming that they start with a tcp header and +// parseTCP parses the bytes assuming that they start with a tcp header and // continues parsing further encapsulations. -func ParseTCP(b []byte) (Layers, error) { +func parseTCP(b []byte) (Layer, layerParser) { h := header.TCP(b) tcp := TCP{ SrcPort: Uint16(h.SourcePort()), @@ -483,12 +530,7 @@ func ParseTCP(b []byte) (Layers, error) { Checksum: Uint16(h.Checksum()), UrgentPointer: Uint16(h.UrgentPointer()), } - layers := Layers{&tcp} - moreLayers, err := ParsePayload(b[tcp.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + return &tcp, parsePayload } func (l *TCP) match(other Layer) bool { @@ -504,8 +546,8 @@ func (l *TCP) length() int { // merge overrides the values in l with the values from other but only in fields // where the value is not nil. -func (l *TCP) merge(other TCP) error { - return mergo.Merge(l, other, mergo.WithOverride) +func (l *TCP) merge(other Layer) error { + return mergeLayer(l, other) } // UDP can construct and match a UDP encapsulation. @@ -556,9 +598,9 @@ func setUDPChecksum(h *header.UDP, udp *UDP) error { return nil } -// ParseUDP parses the bytes assuming that they start with a udp header and -// continues parsing further encapsulations. -func ParseUDP(b []byte) (Layers, error) { +// parseUDP parses the bytes assuming that they start with a udp header and +// returns the parsed layer and the next parser to use. +func parseUDP(b []byte) (Layer, layerParser) { h := header.UDP(b) udp := UDP{ SrcPort: Uint16(h.SourcePort()), @@ -566,12 +608,7 @@ func ParseUDP(b []byte) (Layers, error) { Length: Uint16(h.Length()), Checksum: Uint16(h.Checksum()), } - layers := Layers{&udp} - moreLayers, err := ParsePayload(b[udp.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + return &udp, parsePayload } func (l *UDP) match(other Layer) bool { @@ -587,8 +624,8 @@ func (l *UDP) length() int { // merge overrides the values in l with the values from other but only in fields // where the value is not nil. -func (l *UDP) merge(other UDP) error { - return mergo.Merge(l, other, mergo.WithOverride) +func (l *UDP) merge(other Layer) error { + return mergeLayer(l, other) } // Payload has bytes beyond OSI layer 4. @@ -601,13 +638,13 @@ func (l *Payload) String() string { return stringLayer(l) } -// ParsePayload parses the bytes assuming that they start with a payload and +// parsePayload parses the bytes assuming that they start with a payload and // continue to the end. There can be no further encapsulations. -func ParsePayload(b []byte) (Layers, error) { +func parsePayload(b []byte) (Layer, layerParser) { payload := Payload{ Bytes: b, } - return Layers{&payload}, nil + return &payload, nil } func (l *Payload) toBytes() ([]byte, error) { @@ -622,18 +659,33 @@ func (l *Payload) length() int { return len(l.Bytes) } +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *Payload) merge(other Layer) error { + return mergeLayer(l, other) +} + // Layers is an array of Layer and supports similar functions to Layer. type Layers []Layer -func (ls *Layers) toBytes() ([]byte, error) { +// linkLayers sets the linked-list ponters in ls. +func (ls *Layers) linkLayers() { for i, l := range *ls { if i > 0 { l.setPrev((*ls)[i-1]) + } else { + l.setPrev(nil) } if i+1 < len(*ls) { l.setNext((*ls)[i+1]) + } else { + l.setNext(nil) } } +} + +func (ls *Layers) toBytes() ([]byte, error) { + ls.linkLayers() outBytes := []byte{} for _, l := range *ls { layerBytes, err := l.toBytes() @@ -649,8 +701,8 @@ func (ls *Layers) match(other Layers) bool { if len(*ls) > len(other) { return false } - for i := 0; i < len(*ls); i++ { - if !equalLayer((*ls)[i], other[i]) { + for i, l := range *ls { + if !equalLayer(l, other[i]) { return false } } diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go index b39839625..b32efda93 100644 --- a/test/packetimpact/testbench/layers_test.go +++ b/test/packetimpact/testbench/layers_test.go @@ -14,7 +14,11 @@ package testbench -import "testing" +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) func TestLayerMatch(t *testing.T) { var nilPayload *Payload @@ -47,3 +51,106 @@ func TestLayerMatch(t *testing.T) { } } } + +func TestLayerStringFormat(t *testing.T) { + for _, tt := range []struct { + name string + l Layer + want string + }{ + { + name: "TCP", + l: &TCP{ + SrcPort: Uint16(34785), + DstPort: Uint16(47767), + SeqNum: Uint32(3452155723), + AckNum: Uint32(2596996163), + DataOffset: Uint8(5), + Flags: Uint8(20), + WindowSize: Uint16(64240), + Checksum: Uint16(0x2e2b), + }, + want: "&testbench.TCP{" + + "SrcPort:34785 " + + "DstPort:47767 " + + "SeqNum:3452155723 " + + "AckNum:2596996163 " + + "DataOffset:5 " + + "Flags:20 " + + "WindowSize:64240 " + + "Checksum:11819" + + "}", + }, + { + name: "UDP", + l: &UDP{ + SrcPort: Uint16(34785), + DstPort: Uint16(47767), + Length: Uint16(12), + }, + want: "&testbench.UDP{" + + "SrcPort:34785 " + + "DstPort:47767 " + + "Length:12" + + "}", + }, + { + name: "IPv4", + l: &IPv4{ + IHL: Uint8(5), + TOS: Uint8(0), + TotalLength: Uint16(44), + ID: Uint16(0), + Flags: Uint8(2), + FragmentOffset: Uint16(0), + TTL: Uint8(64), + Protocol: Uint8(6), + Checksum: Uint16(0x2e2b), + SrcAddr: Address(tcpip.Address([]byte{197, 34, 63, 10})), + DstAddr: Address(tcpip.Address([]byte{197, 34, 63, 20})), + }, + want: "&testbench.IPv4{" + + "IHL:5 " + + "TOS:0 " + + "TotalLength:44 " + + "ID:0 " + + "Flags:2 " + + "FragmentOffset:0 " + + "TTL:64 " + + "Protocol:6 " + + "Checksum:11819 " + + "SrcAddr:197.34.63.10 " + + "DstAddr:197.34.63.20" + + "}", + }, + { + name: "Ether", + l: &Ether{ + SrcAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x0a})), + DstAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x14})), + Type: NetworkProtocolNumber(4), + }, + want: "&testbench.Ether{" + + "SrcAddr:02:42:c5:22:3f:0a " + + "DstAddr:02:42:c5:22:3f:14 " + + "Type:4" + + "}", + }, + { + name: "Payload", + l: &Payload{ + Bytes: []byte("Hooray for packetimpact."), + }, + want: "&testbench.Payload{Bytes:\n" + + "00000000 48 6f 6f 72 61 79 20 66 6f 72 20 70 61 63 6b 65 |Hooray for packe|\n" + + "00000010 74 69 6d 70 61 63 74 2e |timpact.|\n" + + "}", + }, + } { + t.Run(tt.name, func(t *testing.T) { + if got := tt.l.String(); got != tt.want { + t.Errorf("%s.String() = %s, want: %s", tt.name, got, tt.want) + } + }) + } +} diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go index 0074484f7..ff722d4a6 100644 --- a/test/packetimpact/testbench/rawsockets.go +++ b/test/packetimpact/testbench/rawsockets.go @@ -17,6 +17,7 @@ package testbench import ( "encoding/binary" "flag" + "fmt" "math" "net" "testing" @@ -97,12 +98,36 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte { } } -// Close the socket that Sniffer is using. -func (s *Sniffer) Close() { +// Drain drains the Sniffer's socket receive buffer by receiving until there's +// nothing else to receive. +func (s *Sniffer) Drain() { + s.t.Helper() + flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0) + if err != nil { + s.t.Fatalf("failed to get sniffer socket fd flags: %s", err) + } + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags|unix.O_NONBLOCK); err != nil { + s.t.Fatalf("failed to make sniffer socket non-blocking: %s", err) + } + for { + buf := make([]byte, maxReadSize) + _, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC) + if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK { + break + } + } + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil { + s.t.Fatalf("failed to restore sniffer socket fd flags: %s", err) + } +} + +// close the socket that Sniffer is using. +func (s *Sniffer) close() error { if err := unix.Close(s.fd); err != nil { - s.t.Fatalf("can't close sniffer socket: %s", err) + return fmt.Errorf("can't close sniffer socket: %w", err) } s.fd = -1 + return nil } // Injector can inject raw frames. @@ -148,10 +173,11 @@ func (i *Injector) Send(b []byte) { } } -// Close the underlying socket. -func (i *Injector) Close() { +// close the underlying socket. +func (i *Injector) close() error { if err := unix.Close(i.fd); err != nil { - i.t.Fatalf("can't close sniffer socket: %s", err) + return fmt.Errorf("can't close sniffer socket: %w", err) } i.fd = -1 + return nil } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index a9b2de9b9..690cee140 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -40,6 +40,54 @@ packetimpact_go_test( ], ) +packetimpact_go_test( + name = "tcp_outside_the_window", + srcs = ["tcp_outside_the_window_test.go"], + # TODO(eyalsoha): Fix #1607 then remove the line below. + netstack = False, + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_noaccept_close_rst", + srcs = ["tcp_noaccept_close_rst_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_should_piggyback", + srcs = ["tcp_should_piggyback_test.go"], + # TODO(b/153680566): Fix netstack then remove the line below. + netstack = False, + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_close_wait_ack", + srcs = ["tcp_close_wait_ack_test.go"], + # TODO(b/153574037): Fix netstack then remove the line below. + netstack = False, + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + sh_binary( name = "test_runner", srcs = ["test_runner.sh"], diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go index 2b3f39045..b98594f94 100644 --- a/test/packetimpact/tests/fin_wait2_timeout_test.go +++ b/test/packetimpact/tests/fin_wait2_timeout_test.go @@ -47,20 +47,22 @@ func TestFinWait2Timeout(t *testing.T) { } dut.Close(acceptFd) - if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); gotOne == nil { - t.Fatal("expected a FIN-ACK within 1 second but got none") + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err) } conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) time.Sleep(5 * time.Second) + conn.Drain() + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) if tt.linger2 { - if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); gotOne == nil { - t.Fatal("expected a RST packet within a second but got none") + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected a RST packet within a second but got none: %s", err) } } else { - if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, 10*time.Second); gotOne != nil { - t.Fatal("expected no RST packets within ten seconds but got one") + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, 10*time.Second); err == nil { + t.Fatalf("expected no RST packets within ten seconds but got one: %s", err) } } }) diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go new file mode 100644 index 000000000..eb4cc7a65 --- /dev/null +++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go @@ -0,0 +1,102 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_close_wait_ack_test + +import ( + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestCloseWaitAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP + seqNumOffset seqnum.Size + expectAck bool + }{ + {"OTW", GenerateOTWSeqSegment, 0, false}, + {"OTW", GenerateOTWSeqSegment, 1, true}, + {"OTW", GenerateOTWSeqSegment, 2, true}, + {"ACK", GenerateUnaccACKSegment, 0, false}, + {"ACK", GenerateUnaccACKSegment, 1, true}, + {"ACK", GenerateUnaccACKSegment, 2, true}, + } { + t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Handshake() + acceptFd, _ := dut.Accept(listenFd) + + // Send a FIN to DUT to intiate the active close + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err) + } + + // Send a segment with OTW Seq / unacc ACK and expect an ACK back + conn.Send(tt.makeTestingTCP(&conn, tt.seqNumOffset), &tb.Payload{Bytes: []byte("Sample Data")}) + gotAck, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second) + if tt.expectAck && err != nil { + t.Fatalf("expected an ack but got none: %s", err) + } + if !tt.expectAck && gotAck != nil { + t.Fatalf("expected no ack but got one: %s", gotAck) + } + + // Now let's verify DUT is indeed in CLOSE_WAIT + dut.Close(acceptFd) + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { + t.Fatalf("expected DUT to send a FIN: %s", err) + } + // Ack the FIN from DUT + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + // Send some extra data to DUT + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, &tb.Payload{Bytes: []byte("Sample Data")}) + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected DUT to send an RST: %s", err) + } + }) + } +} + +// This generates an segment with seqnum = RCV.NXT + RCV.WND + seqNumOffset, the +// generated segment is only acceptable when seqNumOffset is 0, otherwise an ACK +// is expected from the receiver. +func GenerateOTWSeqSegment(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP { + windowSize := seqnum.Size(*conn.SynAck().WindowSize) + lastAcceptable := conn.LocalSeqNum().Add(windowSize - 1) + otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) + return tb.TCP{SeqNum: tb.Uint32(otwSeq), Flags: tb.Uint8(header.TCPFlagAck)} +} + +// This generates an segment with acknum = SND.NXT + seqNumOffset, the generated +// segment is only acceptable when seqNumOffset is 0, otherwise an ACK is +// expected from the receiver. +func GenerateUnaccACKSegment(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP { + lastAcceptable := conn.RemoteSeqNum() + unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) + return tb.TCP{AckNum: tb.Uint32(unaccAck), Flags: tb.Uint8(header.TCPFlagAck)} +} diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go new file mode 100644 index 000000000..7ebdd1950 --- /dev/null +++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_noaccept_close_rst_test + +import ( + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestTcpNoAcceptCloseReset(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + conn.Handshake() + defer conn.Close() + dut.Close(listenFd) + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { + t.Fatalf("expected a RST-ACK packet but got none: %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go new file mode 100644 index 000000000..db3d3273b --- /dev/null +++ b/test/packetimpact/tests/tcp_outside_the_window_test.go @@ -0,0 +1,88 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_outside_the_window_test + +import ( + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +// TestTCPOutsideTheWindows tests the behavior of the DUT when packets arrive +// that are inside or outside the TCP window. Packets that are outside the +// window should force an extra ACK, as described in RFC793 page 69: +// https://tools.ietf.org/html/rfc793#page-69 +func TestTCPOutsideTheWindow(t *testing.T) { + for _, tt := range []struct { + description string + tcpFlags uint8 + payload []tb.Layer + seqNumOffset seqnum.Size + expectACK bool + }{ + {"SYN", header.TCPFlagSyn, nil, 0, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 0, true}, + {"ACK", header.TCPFlagAck, nil, 0, false}, + {"FIN", header.TCPFlagFin, nil, 0, false}, + {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 0, true}, + + {"SYN", header.TCPFlagSyn, nil, 1, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 1, true}, + {"ACK", header.TCPFlagAck, nil, 1, true}, + {"FIN", header.TCPFlagFin, nil, 1, false}, + {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 1, true}, + + {"SYN", header.TCPFlagSyn, nil, 2, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 2, true}, + {"ACK", header.TCPFlagAck, nil, 2, true}, + {"FIN", header.TCPFlagFin, nil, 2, false}, + {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 2, true}, + } { + t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + conn.Handshake() + acceptFD, _ := dut.Accept(listenFD) + defer dut.Close(acceptFD) + + windowSize := seqnum.Size(*conn.SynAck().WindowSize) + tt.seqNumOffset + conn.Drain() + // Ignore whatever incrementing that this out-of-order packet might cause + // to the AckNum. + localSeqNum := tb.Uint32(uint32(*conn.LocalSeqNum())) + conn.Send(tb.TCP{ + Flags: tb.Uint8(tt.tcpFlags), + SeqNum: tb.Uint32(uint32(conn.LocalSeqNum().Add(windowSize))), + }, tt.payload...) + timeout := 3 * time.Second + gotACK, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout) + if tt.expectACK && err != nil { + t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err) + } + if !tt.expectACK && gotACK != nil { + t.Fatalf("expected no ACK packet within %s but got one: %s", timeout, gotACK) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_should_piggyback_test.go b/test/packetimpact/tests/tcp_should_piggyback_test.go new file mode 100644 index 000000000..b0be6ba23 --- /dev/null +++ b/test/packetimpact/tests/tcp_should_piggyback_test.go @@ -0,0 +1,59 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_should_piggyback_test + +import ( + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestPiggyback(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort, WindowSize: tb.Uint16(12)}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Handshake() + acceptFd, _ := dut.Accept(listenFd) + defer dut.Close(acceptFd) + + dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + + dut.Send(acceptFd, sampleData, 0) + expectedTCP := tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagPsh)} + expectedPayload := tb.Payload{Bytes: sampleData} + if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + t.Fatalf("Expected %v but didn't get one: %s", tb.Layers{&expectedTCP, &expectedPayload}, err) + } + + // Cause DUT to send us more data as soon as we ACK their first data segment because we have + // a small window. + dut.Send(acceptFd, sampleData, 0) + + // DUT should ACK our segment by piggybacking ACK to their outstanding data segment instead of + // sending a separate ACK packet. + conn.Send(expectedTCP, &expectedPayload) + if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + t.Fatalf("Expected %v but didn't get one: %s", tb.Layers{&expectedTCP, &expectedPayload}, err) + } +} diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go index b48cc6491..c9354074e 100644 --- a/test/packetimpact/tests/tcp_window_shrink_test.go +++ b/test/packetimpact/tests/tcp_window_shrink_test.go @@ -38,15 +38,22 @@ func TestWindowShrink(t *testing.T) { dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") + samplePayload := &tb.Payload{Bytes: sampleData} dut.Send(acceptFd, sampleData, 0) - conn.ExpectData(tb.TCP{}, sampleData, time.Second) + if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) dut.Send(acceptFd, sampleData, 0) dut.Send(acceptFd, sampleData, 0) - conn.ExpectData(tb.TCP{}, sampleData, time.Second) - conn.ExpectData(tb.TCP{}, sampleData, time.Second) + if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } // We close our receiving window here conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), WindowSize: tb.Uint16(0)}) @@ -54,5 +61,8 @@ func TestWindowShrink(t *testing.T) { // Note: There is another kind of zero-window probing which Windows uses (by sending one // new byte at `RemoteSeqNum`), if netstack wants to go that way, we may want to change // the following lines. - conn.ExpectData(tb.TCP{SeqNum: tb.Uint32(uint32(conn.RemoteSeqNum - 1))}, nil, time.Second) + expectedRemoteSeqNum := *conn.RemoteSeqNum() - 1 + if _, err := conn.ExpectData(&tb.TCP{SeqNum: tb.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil { + t.Fatalf("expected a packet with sequence number %v: %s", expectedRemoteSeqNum, err) + } } diff --git a/test/packetimpact/tests/udp_recv_multicast_test.go b/test/packetimpact/tests/udp_recv_multicast_test.go index bc1b0be49..61fd17050 100644 --- a/test/packetimpact/tests/udp_recv_multicast_test.go +++ b/test/packetimpact/tests/udp_recv_multicast_test.go @@ -30,7 +30,7 @@ func TestUDPRecvMulticast(t *testing.T) { defer dut.Close(boundFD) conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort}) defer conn.Close() - frame := conn.CreateFrame(tb.UDP{}, &tb.Payload{Bytes: []byte("hello world")}) + frame := conn.CreateFrame(&tb.UDP{}, &tb.Payload{Bytes: []byte("hello world")}) frame[1].(*tb.IPv4).DstAddr = tb.Address(tcpip.Address(net.ParseIP("224.0.0.1").To4())) conn.SendFrame(frame) dut.Recv(boundFD, 100, 0) |