// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package testbench has utilities to send and receive packets and also command
// the DUT to run POSIX functions.
package testbench

import (
	"flag"
	"fmt"
	"math/rand"
	"net"
	"testing"
	"time"

	"github.com/mohae/deepcopy"
	"golang.org/x/sys/unix"
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)

// Command-line flags that supply the link-layer and IPv4 addresses used when
// building outgoing test packets and matching incoming ones. All of them
// default to the empty string.
var localIPv4 = flag.String("local_ipv4", "", "local IPv4 address for test packets")
var remoteIPv4 = flag.String("remote_ipv4", "", "remote IPv4 address for test packets")
var localMAC = flag.String("local_mac", "", "local mac address for test packets")
var remoteMAC = flag.String("remote_mac", "", "remote mac address for test packets")

// TCPIPv4 maintains state about a TCP/IPv4 connection.
type TCPIPv4 struct {
	// outgoing holds the default Ethernet/IPv4/TCP layers; Send deep-copies
	// them and merges per-packet overrides on top.
	outgoing Layers
	// incoming holds the layers that sniffed packets are matched against in
	// Recv; packets that do not match are ignored.
	incoming Layers
	// LocalSeqNum is the sequence number Send uses when the caller does not
	// provide one. Send advances it after each packet.
	LocalSeqNum seqnum.Value
	// RemoteSeqNum is the acknowledgement number Send uses when the caller
	// does not provide one. Recv updates it from each matching packet.
	RemoteSeqNum seqnum.Value
	// SynAck is the SYN-ACK packet received during Handshake.
	SynAck *TCP
	// sniffer is the source of received packets.
	sniffer Sniffer
	// injector is the sink for sent packets.
	injector Injector
	// portPickerFD is the socket returned by pickPort. It stays open until
	// Close so that the local port remains bound for the connection's
	// lifetime.
	portPickerFD int
	// t is used to fail the test when an operation cannot proceed.
	t *testing.T
}

// pickPort makes a new socket and returns the socket FD and port. The caller
// must close the FD when done with the port if there is no error.
func pickPort() (int, uint16, error) { fd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, 0) if err != nil { return -1, 0, err } var sa unix.SockaddrInet4 copy(sa.Addr[0:4], net.ParseIP(*localIPv4).To4()) if err := unix.Bind(fd, &sa); err != nil { unix.Close(fd) return -1, 0, err } newSockAddr, err := unix.Getsockname(fd) if err != nil { unix.Close(fd) return -1, 0, err } newSockAddrInet4, ok := newSockAddr.(*unix.SockaddrInet4) if !ok { unix.Close(fd) return -1, 0, fmt.Errorf("can't cast Getsockname result to SockaddrInet4") } return fd, uint16(newSockAddrInet4.Port), nil } // tcpLayerIndex is the position of the TCP layer in the TCPIPv4 connection. It // is the third, after Ethernet and IPv4. const tcpLayerIndex int = 2 // NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults. func NewTCPIPv4(t *testing.T, dut DUT, outgoingTCP, incomingTCP TCP) TCPIPv4 { lMAC, err := tcpip.ParseMACAddress(*localMAC) if err != nil { t.Fatalf("can't parse localMAC %q: %s", *localMAC, err) } rMAC, err := tcpip.ParseMACAddress(*remoteMAC) if err != nil { t.Fatalf("can't parse remoteMAC %q: %s", *remoteMAC, err) } portPickerFD, localPort, err := pickPort() if err != nil { t.Fatalf("can't pick a port: %s", err) } lIP := tcpip.Address(net.ParseIP(*localIPv4).To4()) rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4()) sniffer, err := NewSniffer(t) if err != nil { t.Fatalf("can't make new sniffer: %s", err) } injector, err := NewInjector(t) if err != nil { t.Fatalf("can't make new injector: %s", err) } newOutgoingTCP := &TCP{ DataOffset: Uint8(header.TCPMinimumSize), WindowSize: Uint16(32768), SrcPort: &localPort, } if err := newOutgoingTCP.merge(outgoingTCP); err != nil { t.Fatalf("can't merge %v into %v: %s", outgoingTCP, newOutgoingTCP, err) } newIncomingTCP := &TCP{ DstPort: &localPort, } if err := newIncomingTCP.merge(incomingTCP); err != nil { t.Fatalf("can't merge %v into %v: %s", incomingTCP, newIncomingTCP, err) } return TCPIPv4{ outgoing: Layers{ 
&Ether{SrcAddr: &lMAC, DstAddr: &rMAC}, &IPv4{SrcAddr: &lIP, DstAddr: &rIP}, newOutgoingTCP}, incoming: Layers{ &Ether{SrcAddr: &rMAC, DstAddr: &lMAC}, &IPv4{SrcAddr: &rIP, DstAddr: &lIP}, newIncomingTCP}, sniffer: sniffer, injector: injector, portPickerFD: portPickerFD, t: t, LocalSeqNum: seqnum.Value(rand.Uint32()), } } // Close the injector and sniffer associated with this connection. func (conn *TCPIPv4) Close() { conn.sniffer.Close() conn.injector.Close() if err := unix.Close(conn.portPickerFD); err != nil { conn.t.Fatalf("can't close portPickerFD: %s", err) } conn.portPickerFD = -1 } // Send a packet with reasonable defaults and override some fields by tcp. func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { if tcp.SeqNum == nil { tcp.SeqNum = Uint32(uint32(conn.LocalSeqNum)) } if tcp.AckNum == nil { tcp.AckNum = Uint32(uint32(conn.RemoteSeqNum)) } layersToSend := deepcopy.Copy(conn.outgoing).(Layers) if err := layersToSend[tcpLayerIndex].(*TCP).merge(tcp); err != nil { conn.t.Fatalf("can't merge %v into %v: %s", tcp, layersToSend[tcpLayerIndex], err) } layersToSend = append(layersToSend, additionalLayers...) outBytes, err := layersToSend.toBytes() if err != nil { conn.t.Fatalf("can't build outgoing TCP packet: %s", err) } conn.injector.Send(outBytes) // Compute the next TCP sequence number. for i := tcpLayerIndex + 1; i < len(layersToSend); i++ { conn.LocalSeqNum.UpdateForward(seqnum.Size(layersToSend[i].length())) } if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { conn.LocalSeqNum.UpdateForward(1) } } // Recv gets a packet from the sniffer within the timeout provided. If no packet // arrives before the timeout, it returns nil. 
func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP { deadline := time.Now().Add(timeout) for { timeout = deadline.Sub(time.Now()) if timeout <= 0 { break } b := conn.sniffer.Recv(timeout) if b == nil { break } layers, err := ParseEther(b) if err != nil { continue // Ignore packets that can't be parsed. } if !conn.incoming.match(layers) { continue // Ignore packets that don't match the expected incoming. } tcpHeader := (layers[tcpLayerIndex]).(*TCP) conn.RemoteSeqNum = seqnum.Value(*tcpHeader.SeqNum) if *tcpHeader.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { conn.RemoteSeqNum.UpdateForward(1) } for i := tcpLayerIndex + 1; i < len(layers); i++ { conn.RemoteSeqNum.UpdateForward(seqnum.Size(layers[i].length())) } return tcpHeader } return nil } // Expect a packet that matches the provided tcp within the timeout specified. // If it doesn't arrive in time, the test fails. func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) *TCP { deadline := time.Now().Add(timeout) for { timeout = deadline.Sub(time.Now()) if timeout <= 0 { return nil } gotTCP := conn.Recv(timeout) if gotTCP == nil { return nil } if tcp.match(gotTCP) { return gotTCP } } } // Handshake performs a TCP 3-way handshake. func (conn *TCPIPv4) Handshake() { // Send the SYN. conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)}) // Wait for the SYN-ACK. conn.SynAck = conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) if conn.SynAck == nil { conn.t.Fatalf("didn't get synack during handshake") } // Send an ACK. conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) }