diff options
Diffstat (limited to 'test/packetimpact/testbench')
-rw-r--r-- | test/packetimpact/testbench/BUILD | 45 | ||||
-rw-r--r-- | test/packetimpact/testbench/connections.go | 1222 | ||||
-rw-r--r-- | test/packetimpact/testbench/dut.go | 798 | ||||
-rw-r--r-- | test/packetimpact/testbench/dut_client.go | 28 | ||||
-rw-r--r-- | test/packetimpact/testbench/layers.go | 1595 | ||||
-rw-r--r-- | test/packetimpact/testbench/layers_test.go | 728 | ||||
-rw-r--r-- | test/packetimpact/testbench/rawsockets.go | 207 | ||||
-rw-r--r-- | test/packetimpact/testbench/testbench.go | 157 |
8 files changed, 0 insertions, 4780 deletions
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD deleted file mode 100644 index 43b4c7ca1..000000000 --- a/test/packetimpact/testbench/BUILD +++ /dev/null @@ -1,45 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package( - licenses = ["notice"], -) - -go_library( - name = "testbench", - srcs = [ - "connections.go", - "dut.go", - "dut_client.go", - "layers.go", - "rawsockets.go", - "testbench.go", - ], - visibility = ["//test/packetimpact:__subpackages__"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - "//pkg/usermem", - "//test/packetimpact/proto:posix_server_go_proto", - "@com_github_google_go_cmp//cmp:go_default_library", - "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", - "@com_github_mohae_deepcopy//:go_default_library", - "@org_golang_google_grpc//:go_default_library", - "@org_golang_google_grpc//keepalive:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - "@org_uber_go_multierr//:go_default_library", - ], -) - -go_test( - name = "testbench_test", - size = "small", - srcs = ["layers_test.go"], - library = ":testbench", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "@com_github_mohae_deepcopy//:go_default_library", - ], -) diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go deleted file mode 100644 index 15e1a51de..000000000 --- a/test/packetimpact/testbench/connections.go +++ /dev/null @@ -1,1222 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testbench - -import ( - "fmt" - "math/rand" - "testing" - "time" - - "github.com/mohae/deepcopy" - "go.uber.org/multierr" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" -) - -func portFromSockaddr(sa unix.Sockaddr) (uint16, error) { - switch sa := sa.(type) { - case *unix.SockaddrInet4: - return uint16(sa.Port), nil - case *unix.SockaddrInet6: - return uint16(sa.Port), nil - } - return 0, fmt.Errorf("sockaddr type %T does not contain port", sa) -} - -// pickPort makes a new socket and returns the socket FD and port. The domain -// should be AF_INET or AF_INET6. The caller must close the FD when done with -// the port if there is no error. -func (n *DUTTestNet) pickPort(domain, typ int) (fd int, port uint16, err error) { - fd, err = unix.Socket(domain, typ, 0) - if err != nil { - return -1, 0, fmt.Errorf("creating socket: %w", err) - } - defer func() { - if err != nil { - if cerr := unix.Close(fd); cerr != nil { - err = multierr.Append(err, fmt.Errorf("failed to close socket %d: %w", fd, cerr)) - } - } - }() - var sa unix.Sockaddr - switch domain { - case unix.AF_INET: - var sa4 unix.SockaddrInet4 - copy(sa4.Addr[:], n.LocalIPv4) - sa = &sa4 - case unix.AF_INET6: - sa6 := unix.SockaddrInet6{ZoneId: n.LocalDevID} - copy(sa6.Addr[:], n.LocalIPv6) - sa = &sa6 - default: - return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain) - } - if err = unix.Bind(fd, sa); err != nil { - return -1, 0, fmt.Errorf("binding to %+v: %w", sa, err) - } - sa, err = unix.Getsockname(fd) - if err != nil { - return -1, 0, fmt.Errorf("unix.Getsocketname(%d): %w", fd, err) - } - port, err = portFromSockaddr(sa) - if err != nil { - return -1, 0, fmt.Errorf("extracting port from socket address %+v: %w", sa, err) - } - return fd, port, nil -} - -// layerState stores the state of a layer of a connection. -type layerState interface { - // outgoing returns an outgoing layer to be sent in a frame. It should not - // update layerState, that is done in layerState.sent. - outgoing() Layer - - // incoming creates an expected Layer for comparing against a received Layer. - // Because the expectation can depend on values in the received Layer, it is - // an input to incoming. For example, the ACK number needs to be checked in a - // TCP packet but only if the ACK flag is set in the received packet. It - // should not update layerState, that is done in layerState.received. The - // caller takes ownership of the returned Layer. - incoming(received Layer) Layer - - // sent updates the layerState based on the Layer that was sent. The input is - // a Layer with all prev and next pointers populated so that the entire frame - // as it was sent is available. - sent(sent Layer) error - - // received updates the layerState based on a Layer that is received. The - // input is a Layer with all prev and next pointers populated so that the - // entire frame as it was received is available. - received(received Layer) error - - // close frees associated resources held by the LayerState. - close() error -} - -// etherState maintains state about an Ethernet connection. -type etherState struct { - out, in Ether -} - -var _ layerState = (*etherState)(nil) - -// newEtherState creates a new etherState. -func (n *DUTTestNet) newEtherState(out, in Ether) (*etherState, error) { - lmac := tcpip.LinkAddress(n.LocalMAC) - rmac := tcpip.LinkAddress(n.RemoteMAC) - s := etherState{ - out: Ether{SrcAddr: &lmac, DstAddr: &rmac}, - in: Ether{SrcAddr: &rmac, DstAddr: &lmac}, - } - if err := s.out.merge(&out); err != nil { - return nil, err - } - if err := s.in.merge(&in); err != nil { - return nil, err - } - return &s, nil -} - -func (s *etherState) outgoing() Layer { - return deepcopy.Copy(&s.out).(Layer) -} - -// incoming implements layerState.incoming. -func (s *etherState) incoming(Layer) Layer { - return deepcopy.Copy(&s.in).(Layer) -} - -func (*etherState) sent(Layer) error { - return nil -} - -func (*etherState) received(Layer) error { - return nil -} - -func (*etherState) close() error { - return nil -} - -// ipv4State maintains state about an IPv4 connection. -type ipv4State struct { - out, in IPv4 -} - -var _ layerState = (*ipv4State)(nil) - -// newIPv4State creates a new ipv4State. -func (n *DUTTestNet) newIPv4State(out, in IPv4) (*ipv4State, error) { - lIP := tcpip.Address(n.LocalIPv4) - rIP := tcpip.Address(n.RemoteIPv4) - s := ipv4State{ - out: IPv4{SrcAddr: &lIP, DstAddr: &rIP}, - in: IPv4{SrcAddr: &rIP, DstAddr: &lIP}, - } - if err := s.out.merge(&out); err != nil { - return nil, err - } - if err := s.in.merge(&in); err != nil { - return nil, err - } - return &s, nil -} - -func (s *ipv4State) outgoing() Layer { - return deepcopy.Copy(&s.out).(Layer) -} - -// incoming implements layerState.incoming. -func (s *ipv4State) incoming(Layer) Layer { - return deepcopy.Copy(&s.in).(Layer) -} - -func (*ipv4State) sent(Layer) error { - return nil -} - -func (*ipv4State) received(Layer) error { - return nil -} - -func (*ipv4State) close() error { - return nil -} - -// ipv6State maintains state about an IPv6 connection. -type ipv6State struct { - out, in IPv6 -} - -var _ layerState = (*ipv6State)(nil) - -// newIPv6State creates a new ipv6State. -func (n *DUTTestNet) newIPv6State(out, in IPv6) (*ipv6State, error) { - lIP := tcpip.Address(n.LocalIPv6) - rIP := tcpip.Address(n.RemoteIPv6) - s := ipv6State{ - out: IPv6{SrcAddr: &lIP, DstAddr: &rIP}, - in: IPv6{SrcAddr: &rIP, DstAddr: &lIP}, - } - if err := s.out.merge(&out); err != nil { - return nil, err - } - if err := s.in.merge(&in); err != nil { - return nil, err - } - return &s, nil -} - -// outgoing returns an outgoing layer to be sent in a frame. -func (s *ipv6State) outgoing() Layer { - return deepcopy.Copy(&s.out).(Layer) -} - -func (s *ipv6State) incoming(Layer) Layer { - return deepcopy.Copy(&s.in).(Layer) -} - -func (s *ipv6State) sent(Layer) error { - // Nothing to do. - return nil -} - -func (s *ipv6State) received(Layer) error { - // Nothing to do. - return nil -} - -// close cleans up any resources held. -func (s *ipv6State) close() error { - return nil -} - -// tcpState maintains state about a TCP connection. -type tcpState struct { - out, in TCP - localSeqNum, remoteSeqNum *seqnum.Value - synAck *TCP - portPickerFD int - finSent bool -} - -var _ layerState = (*tcpState)(nil) - -// SeqNumValue is a helper routine that allocates a new seqnum.Value value to -// store v and returns a pointer to it. -func SeqNumValue(v seqnum.Value) *seqnum.Value { - return &v -} - -// newTCPState creates a new TCPState. -func (n *DUTTestNet) newTCPState(domain int, out, in TCP) (*tcpState, error) { - portPickerFD, localPort, err := n.pickPort(domain, unix.SOCK_STREAM) - if err != nil { - return nil, err - } - s := tcpState{ - out: TCP{SrcPort: &localPort}, - in: TCP{DstPort: &localPort}, - localSeqNum: SeqNumValue(seqnum.Value(rand.Uint32())), - portPickerFD: portPickerFD, - finSent: false, - } - if err := s.out.merge(&out); err != nil { - return nil, err - } - if err := s.in.merge(&in); err != nil { - return nil, err - } - return &s, nil -} - -func (s *tcpState) outgoing() Layer { - newOutgoing := deepcopy.Copy(s.out).(TCP) - if s.localSeqNum != nil { - newOutgoing.SeqNum = Uint32(uint32(*s.localSeqNum)) - } - if s.remoteSeqNum != nil { - newOutgoing.AckNum = Uint32(uint32(*s.remoteSeqNum)) - } - return &newOutgoing -} - -// incoming implements layerState.incoming. -func (s *tcpState) incoming(received Layer) Layer { - tcpReceived, ok := received.(*TCP) - if !ok { - return nil - } - newIn := deepcopy.Copy(s.in).(TCP) - if s.remoteSeqNum != nil { - newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum)) - } - if seq, flags := s.localSeqNum, tcpReceived.Flags; seq != nil && flags != nil && *flags&header.TCPFlagAck != 0 { - // The caller didn't specify an AckNum so we'll expect the calculated one, - // but only if the ACK flag is set because the AckNum is not valid in a - // header if ACK is not set. - newIn.AckNum = Uint32(uint32(*seq)) - } - return &newIn -} - -func (s *tcpState) sent(sent Layer) error { - tcp, ok := sent.(*TCP) - if !ok { - return fmt.Errorf("can't update tcpState with %T Layer", sent) - } - if !s.finSent { - // update localSeqNum by the payload only when FIN is not yet sent by us - for current := tcp.next(); current != nil; current = current.next() { - s.localSeqNum.UpdateForward(seqnum.Size(current.length())) - } - } - if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { - s.localSeqNum.UpdateForward(1) - } - if *tcp.Flags&(header.TCPFlagFin) != 0 { - s.finSent = true - } - return nil -} - -func (s *tcpState) received(l Layer) error { - tcp, ok := l.(*TCP) - if !ok { - return fmt.Errorf("can't update tcpState with %T Layer", l) - } - s.remoteSeqNum = SeqNumValue(seqnum.Value(*tcp.SeqNum)) - if *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { - s.remoteSeqNum.UpdateForward(1) - } - for current := tcp.next(); current != nil; current = current.next() { - s.remoteSeqNum.UpdateForward(seqnum.Size(current.length())) - } - return nil -} - -// close frees the port associated with this connection. -func (s *tcpState) close() error { - if err := unix.Close(s.portPickerFD); err != nil { - return err - } - s.portPickerFD = -1 - return nil -} - -// udpState maintains state about a UDP connection. -type udpState struct { - out, in UDP - portPickerFD int -} - -var _ layerState = (*udpState)(nil) - -// newUDPState creates a new udpState. -func (n *DUTTestNet) newUDPState(domain int, out, in UDP) (*udpState, error) { - portPickerFD, localPort, err := n.pickPort(domain, unix.SOCK_DGRAM) - if err != nil { - return nil, fmt.Errorf("picking port: %w", err) - } - s := udpState{ - out: UDP{SrcPort: &localPort}, - in: UDP{DstPort: &localPort}, - portPickerFD: portPickerFD, - } - if err := s.out.merge(&out); err != nil { - return nil, err - } - if err := s.in.merge(&in); err != nil { - return nil, err - } - return &s, nil -} - -func (s *udpState) outgoing() Layer { - return deepcopy.Copy(&s.out).(Layer) -} - -// incoming implements layerState.incoming. -func (s *udpState) incoming(Layer) Layer { - return deepcopy.Copy(&s.in).(Layer) -} - -func (*udpState) sent(l Layer) error { - return nil -} - -func (*udpState) received(l Layer) error { - return nil -} - -// close frees the port associated with this connection. -func (s *udpState) close() error { - if err := unix.Close(s.portPickerFD); err != nil { - return err - } - s.portPickerFD = -1 - return nil -} - -// Connection holds a collection of layer states for maintaining a connection -// along with sockets for sniffer and injecting packets. -type Connection struct { - layerStates []layerState - injector Injector - sniffer Sniffer -} - -// Returns the default incoming frame against which to match. If received is -// longer than layerStates then that may still count as a match. The reverse is -// never a match and nil is returned. -func (conn *Connection) incoming(received Layers) Layers { - if len(received) < len(conn.layerStates) { - return nil - } - in := Layers{} - for i, s := range conn.layerStates { - toMatch := s.incoming(received[i]) - if toMatch == nil { - return nil - } - in = append(in, toMatch) - } - return in -} - -func (conn *Connection) match(override, received Layers) bool { - toMatch := conn.incoming(received) - if toMatch == nil { - return false // Not enough layers in gotLayers for matching. - } - if err := toMatch.merge(override); err != nil { - return false // Failing to merge is not matching. - } - return toMatch.match(received) -} - -// Close frees associated resources held by the Connection. -func (conn *Connection) Close(t *testing.T) { - t.Helper() - - errs := multierr.Combine(conn.sniffer.close(), conn.injector.close()) - for _, s := range conn.layerStates { - if err := s.close(); err != nil { - errs = multierr.Append(errs, fmt.Errorf("unable to close %+v: %s", s, err)) - } - } - if errs != nil { - t.Fatalf("unable to close %+v: %s", conn, errs) - } -} - -// CreateFrame builds a frame for the connection with defaults overridden -// from the innermost layer out, and additionalLayers added after it. -// -// Note that overrideLayers can have a length that is less than the number -// of layers in this connection, and in such cases the innermost layers are -// overridden first. As an example, valid values of overrideLayers for a TCP- -// over-IPv4-over-Ethernet connection are: nil, [TCP], [IPv4, TCP], and -// [Ethernet, IPv4, TCP]. -func (conn *Connection) CreateFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) Layers { - t.Helper() - - var layersToSend Layers - for i, s := range conn.layerStates { - layer := s.outgoing() - // overrideLayers and conn.layerStates have their tails aligned, so - // to find the index we move backwards by the distance i is to the - // end. - if j := len(overrideLayers) - (len(conn.layerStates) - i); j >= 0 { - if err := layer.merge(overrideLayers[j]); err != nil { - t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err) - } - } - layersToSend = append(layersToSend, layer) - } - layersToSend = append(layersToSend, additionalLayers...) - return layersToSend -} - -// SendFrameStateless sends a frame without updating any of the layer states. -// -// This method is useful for sending out-of-band control messages such as -// ICMP packets, where it would not make sense to update the transport layer's -// state using the ICMP header. -func (conn *Connection) SendFrameStateless(t *testing.T, frame Layers) { - t.Helper() - - outBytes, err := frame.ToBytes() - if err != nil { - t.Fatalf("can't build outgoing packet: %s", err) - } - conn.injector.Send(t, outBytes) -} - -// SendFrame sends a frame on the wire and updates the state of all layers. -func (conn *Connection) SendFrame(t *testing.T, frame Layers) { - t.Helper() - - outBytes, err := frame.ToBytes() - if err != nil { - t.Fatalf("can't build outgoing packet: %s", err) - } - conn.injector.Send(t, outBytes) - - // frame might have nil values where the caller wanted to use default values. - // sentFrame will have no nil values in it because it comes from parsing the - // bytes that were actually sent. - sentFrame := parse(parseEther, outBytes) - // Update the state of each layer based on what was sent. - for i, s := range conn.layerStates { - if err := s.sent(sentFrame[i]); err != nil { - t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err) - } - } -} - -// send sends a packet, possibly with layers of this connection overridden and -// additional layers added. -// -// Types defined with Connection as the underlying type should expose -// type-safe versions of this method. -func (conn *Connection) send(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) { - t.Helper() - - conn.SendFrame(t, conn.CreateFrame(t, overrideLayers, additionalLayers...)) -} - -// recvFrame gets the next successfully parsed frame (of type Layers) within the -// timeout provided. If no parsable frame arrives before the timeout, it returns -// nil. -func (conn *Connection) recvFrame(t *testing.T, timeout time.Duration) Layers { - t.Helper() - - if timeout <= 0 { - return nil - } - b := conn.sniffer.Recv(t, timeout) - if b == nil { - return nil - } - return parse(parseEther, b) -} - -// layersError stores the Layers that we got and the Layers that we wanted to -// match. -type layersError struct { - got, want Layers -} - -func (e *layersError) Error() string { - return e.got.diff(e.want) -} - -// Expect expects a frame with the final layerStates layer matching the -// provided Layer within the timeout specified. If it doesn't arrive in time, -// an error is returned. -func (conn *Connection) Expect(t *testing.T, layer Layer, timeout time.Duration) (Layer, error) { - t.Helper() - - // Make a frame that will ignore all but the final layer. - layers := make([]Layer, len(conn.layerStates)) - layers[len(layers)-1] = layer - - gotFrame, err := conn.ExpectFrame(t, layers, timeout) - if err != nil { - return nil, err - } - if len(conn.layerStates)-1 < len(gotFrame) { - return gotFrame[len(conn.layerStates)-1], nil - } - t.Fatalf("the received frame should be at least as long as the expected layers, got %d layers, want at least %d layers, got frame: %#v", len(gotFrame), len(conn.layerStates), gotFrame) - panic("unreachable") -} - -// ExpectFrame expects a frame that matches the provided Layers within the -// timeout specified. If one arrives in time, the Layers is returned without an -// error. If it doesn't arrive in time, it returns nil and error is non-nil. -func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Duration) (Layers, error) { - t.Helper() - - deadline := time.Now().Add(timeout) - var errs error - for { - var gotLayers Layers - if timeout := time.Until(deadline); timeout > 0 { - gotLayers = conn.recvFrame(t, timeout) - } - if gotLayers == nil { - if errs == nil { - return nil, fmt.Errorf("got no frames matching %s during %s", layers, timeout) - } - return nil, fmt.Errorf("got frames:\n%w want %s during %s", errs, layers, timeout) - } - if conn.match(layers, gotLayers) { - for i, s := range conn.layerStates { - if err := s.received(gotLayers[i]); err != nil { - t.Fatalf("failed to update test connection's layer states based on received frame: %s", err) - } - } - return gotLayers, nil - } - want := conn.incoming(layers) - if err := want.merge(layers); err != nil { - errs = multierr.Combine(errs, err) - } else { - errs = multierr.Combine(errs, &layersError{got: gotLayers, want: want}) - } - } -} - -// Drain drains the sniffer's receive buffer by receiving packets until there's -// nothing else to receive. -func (conn *Connection) Drain(t *testing.T) { - t.Helper() - - conn.sniffer.Drain(t) -} - -// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection. -type TCPIPv4 struct { - Connection -} - -// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults. -func (n *DUTTestNet) NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { - t.Helper() - - etherState, err := n.newEtherState(Ether{}, Ether{}) - if err != nil { - t.Fatalf("can't make etherState: %s", err) - } - ipv4State, err := n.newIPv4State(IPv4{}, IPv4{}) - if err != nil { - t.Fatalf("can't make ipv4State: %s", err) - } - tcpState, err := n.newTCPState(unix.AF_INET, outgoingTCP, incomingTCP) - if err != nil { - t.Fatalf("can't make tcpState: %s", err) - } - injector, err := n.NewInjector(t) - if err != nil { - t.Fatalf("can't make injector: %s", err) - } - sniffer, err := n.NewSniffer(t) - if err != nil { - t.Fatalf("can't make sniffer: %s", err) - } - - return TCPIPv4{ - Connection: Connection{ - layerStates: []layerState{etherState, ipv4State, tcpState}, - injector: injector, - sniffer: sniffer, - }, - } -} - -// Connect performs a TCP 3-way handshake. The input Connection should have a -// final TCP Layer. -func (conn *TCPIPv4) Connect(t *testing.T) { - t.Helper() - - // Send the SYN. - conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn)}) - - // Wait for the SYN-ACK. - synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) - if err != nil { - t.Fatalf("didn't get synack during handshake: %s", err) - } - conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck - - // Send an ACK. - conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)}) -} - -// ConnectWithOptions performs a TCP 3-way handshake with given TCP options. -// The input Connection should have a final TCP Layer. -func (conn *TCPIPv4) ConnectWithOptions(t *testing.T, options []byte) { - t.Helper() - - // Send the SYN. - conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn), Options: options}) - - // Wait for the SYN-ACK. - synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) - if err != nil { - t.Fatalf("didn't get synack during handshake: %s", err) - } - conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck - - // Send an ACK. - conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)}) -} - -// ExpectData is a convenient method that expects a Layer and the Layer after -// it. If it doesn't arrive in time, it returns nil. -func (conn *TCPIPv4) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { - t.Helper() - - expected := make([]Layer, len(conn.layerStates)) - expected[len(expected)-1] = tcp - if payload != nil { - expected = append(expected, payload) - } - return conn.ExpectFrame(t, expected, timeout) -} - -// ExpectNextData attempts to receive the next incoming segment for the -// connection and expects that to match the given layers. -// -// It differs from ExpectData() in that here we are only interested in the next -// received segment, while ExpectData() can receive multiple segments for the -// connection until there is a match with given layers or a timeout. -func (conn *TCPIPv4) ExpectNextData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { - t.Helper() - - // Receive the first incoming TCP segment for this connection. - got, err := conn.ExpectData(t, &TCP{}, nil, timeout) - if err != nil { - return nil, err - } - - expected := make([]Layer, len(conn.layerStates)) - expected[len(expected)-1] = tcp - if payload != nil { - expected = append(expected, payload) - tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum(t)) - uint32(payload.Length())) - } - if !conn.match(expected, got) { - return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got) - } - return got, nil -} - -// Send a packet with reasonable defaults. Potentially override the TCP layer in -// the connection with the provided layer and add additionLayers. -func (conn *TCPIPv4) Send(t *testing.T, tcp TCP, additionalLayers ...Layer) { - t.Helper() - - conn.send(t, Layers{&tcp}, additionalLayers...) -} - -// Expect expects a frame with the TCP layer matching the provided TCP within -// the timeout specified. If it doesn't arrive in time, an error is returned. -func (conn *TCPIPv4) Expect(t *testing.T, tcp TCP, timeout time.Duration) (*TCP, error) { - t.Helper() - - layer, err := conn.Connection.Expect(t, &tcp, timeout) - if layer == nil { - return nil, err - } - gotTCP, ok := layer.(*TCP) - if !ok { - t.Fatalf("expected %s to be TCP", layer) - } - return gotTCP, err -} - -func (conn *TCPIPv4) tcpState(t *testing.T) *tcpState { - t.Helper() - - state, ok := conn.layerStates[2].(*tcpState) - if !ok { - t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2]) - } - return state -} - -func (conn *TCPIPv4) ipv4State(t *testing.T) *ipv4State { - t.Helper() - - state, ok := conn.layerStates[1].(*ipv4State) - if !ok { - t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1]) - } - return state -} - -// RemoteSeqNum returns the next expected sequence number from the DUT. -func (conn *TCPIPv4) RemoteSeqNum(t *testing.T) *seqnum.Value { - t.Helper() - - return conn.tcpState(t).remoteSeqNum -} - -// LocalSeqNum returns the next sequence number to send from the testbench. -func (conn *TCPIPv4) LocalSeqNum(t *testing.T) *seqnum.Value { - t.Helper() - - return conn.tcpState(t).localSeqNum -} - -// SynAck returns the SynAck that was part of the handshake. -func (conn *TCPIPv4) SynAck(t *testing.T) *TCP { - t.Helper() - - return conn.tcpState(t).synAck -} - -// LocalAddr gets the local socket address of this connection. -func (conn *TCPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 { - t.Helper() - - sa := &unix.SockaddrInet4{Port: int(*conn.tcpState(t).out.SrcPort)} - copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr) - return sa -} - -// IPv4Conn maintains the state for all the layers in a IPv4 connection. -type IPv4Conn struct { - Connection -} - -// NewIPv4Conn creates a new IPv4Conn connection with reasonable defaults. -func (n *DUTTestNet) NewIPv4Conn(t *testing.T, outgoingIPv4, incomingIPv4 IPv4) IPv4Conn { - t.Helper() - - etherState, err := n.newEtherState(Ether{}, Ether{}) - if err != nil { - t.Fatalf("can't make EtherState: %s", err) - } - ipv4State, err := n.newIPv4State(outgoingIPv4, incomingIPv4) - if err != nil { - t.Fatalf("can't make IPv4State: %s", err) - } - - injector, err := n.NewInjector(t) - if err != nil { - t.Fatalf("can't make injector: %s", err) - } - sniffer, err := n.NewSniffer(t) - if err != nil { - t.Fatalf("can't make sniffer: %s", err) - } - - return IPv4Conn{ - Connection: Connection{ - layerStates: []layerState{etherState, ipv4State}, - injector: injector, - sniffer: sniffer, - }, - } -} - -// Send sends a frame with ipv4 overriding the IPv4 layer defaults and -// additionalLayers added after it. -func (c *IPv4Conn) Send(t *testing.T, ipv4 IPv4, additionalLayers ...Layer) { - t.Helper() - - c.send(t, Layers{&ipv4}, additionalLayers...) -} - -// IPv6Conn maintains the state for all the layers in a IPv6 connection. -type IPv6Conn struct { - Connection -} - -// NewIPv6Conn creates a new IPv6Conn connection with reasonable defaults. -func (n *DUTTestNet) NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn { - t.Helper() - - etherState, err := n.newEtherState(Ether{}, Ether{}) - if err != nil { - t.Fatalf("can't make EtherState: %s", err) - } - ipv6State, err := n.newIPv6State(outgoingIPv6, incomingIPv6) - if err != nil { - t.Fatalf("can't make IPv6State: %s", err) - } - - injector, err := n.NewInjector(t) - if err != nil { - t.Fatalf("can't make injector: %s", err) - } - sniffer, err := n.NewSniffer(t) - if err != nil { - t.Fatalf("can't make sniffer: %s", err) - } - - return IPv6Conn{ - Connection: Connection{ - layerStates: []layerState{etherState, ipv6State}, - injector: injector, - sniffer: sniffer, - }, - } -} - -// Send sends a frame with ipv6 overriding the IPv6 layer defaults and -// additionalLayers added after it. -func (conn *IPv6Conn) Send(t *testing.T, ipv6 IPv6, additionalLayers ...Layer) { - t.Helper() - - conn.send(t, Layers{&ipv6}, additionalLayers...) -} - -// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection. -type UDPIPv4 struct { - Connection -} - -// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults. -func (n *DUTTestNet) NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 { - t.Helper() - - etherState, err := n.newEtherState(Ether{}, Ether{}) - if err != nil { - t.Fatalf("can't make etherState: %s", err) - } - ipv4State, err := n.newIPv4State(IPv4{}, IPv4{}) - if err != nil { - t.Fatalf("can't make ipv4State: %s", err) - } - udpState, err := n.newUDPState(unix.AF_INET, outgoingUDP, incomingUDP) - if err != nil { - t.Fatalf("can't make udpState: %s", err) - } - injector, err := n.NewInjector(t) - if err != nil { - t.Fatalf("can't make injector: %s", err) - } - sniffer, err := n.NewSniffer(t) - if err != nil { - t.Fatalf("can't make sniffer: %s", err) - } - - return UDPIPv4{ - Connection: Connection{ - layerStates: []layerState{etherState, ipv4State, udpState}, - injector: injector, - sniffer: sniffer, - }, - } -} - -func (conn *UDPIPv4) udpState(t *testing.T) *udpState { - t.Helper() - - state, ok := conn.layerStates[2].(*udpState) - if !ok { - t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) - } - return state -} - -func (conn *UDPIPv4) ipv4State(t *testing.T) *ipv4State { - t.Helper() - - state, ok := conn.layerStates[1].(*ipv4State) - if !ok { - t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1]) - } - return state -} - -// LocalAddr gets the local socket address of this connection. -func (conn *UDPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 { - t.Helper() - - sa := &unix.SockaddrInet4{Port: int(*conn.udpState(t).out.SrcPort)} - copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr) - return sa -} - -// SrcPort returns the source port of this connection. -func (conn *UDPIPv4) SrcPort(t *testing.T) uint16 { - t.Helper() - - return *conn.udpState(t).out.SrcPort -} - -// Send sends a packet with reasonable defaults, potentially overriding the UDP -// layer and adding additionLayers. -func (conn *UDPIPv4) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { - t.Helper() - - conn.send(t, Layers{&udp}, additionalLayers...) -} - -// SendIP sends a packet with reasonable defaults, potentially overriding the -// UDP and IPv4 headers and adding additionLayers. -func (conn *UDPIPv4) SendIP(t *testing.T, ip IPv4, udp UDP, additionalLayers ...Layer) { - t.Helper() - - conn.send(t, Layers{&ip, &udp}, additionalLayers...) -} - -// SendFrame sends a frame on the wire and updates the state of all layers. -func (conn *UDPIPv4) SendFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) { - conn.send(t, overrideLayers, additionalLayers...) -} - -// Expect expects a frame with the UDP layer matching the provided UDP within -// the timeout specified. If it doesn't arrive in time, an error is returned. -func (conn *UDPIPv4) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { - t.Helper() - - layer, err := conn.Connection.Expect(t, &udp, timeout) - if err != nil { - return nil, err - } - gotUDP, ok := layer.(*UDP) - if !ok { - t.Fatalf("expected %s to be UDP", layer) - } - return gotUDP, nil -} - -// ExpectData is a convenient method that expects a Layer and the Layer after -// it. If it doesn't arrive in time, it returns nil. -func (conn *UDPIPv4) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) { - t.Helper() - - expected := make([]Layer, len(conn.layerStates)) - expected[len(expected)-1] = &udp - if payload.length() != 0 { - expected = append(expected, &payload) - } - return conn.ExpectFrame(t, expected, timeout) -} - -// UDPIPv6 maintains the state for all the layers in a UDP/IPv6 connection. -type UDPIPv6 struct { - Connection -} - -// NewUDPIPv6 creates a new UDPIPv6 connection with reasonable defaults. -func (n *DUTTestNet) NewUDPIPv6(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv6 { - t.Helper() - - etherState, err := n.newEtherState(Ether{}, Ether{}) - if err != nil { - t.Fatalf("can't make etherState: %s", err) - } - ipv6State, err := n.newIPv6State(IPv6{}, IPv6{}) - if err != nil { - t.Fatalf("can't make IPv6State: %s", err) - } - udpState, err := n.newUDPState(unix.AF_INET6, outgoingUDP, incomingUDP) - if err != nil { - t.Fatalf("can't make udpState: %s", err) - } - injector, err := n.NewInjector(t) - if err != nil { - t.Fatalf("can't make injector: %s", err) - } - sniffer, err := n.NewSniffer(t) - if err != nil { - t.Fatalf("can't make sniffer: %s", err) - } - return UDPIPv6{ - Connection: Connection{ - layerStates: []layerState{etherState, ipv6State, udpState}, - injector: injector, - sniffer: sniffer, - }, - } -} - -func (conn *UDPIPv6) udpState(t *testing.T) *udpState { - t.Helper() - - state, ok := conn.layerStates[2].(*udpState) - if !ok { - t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) - } - return state -} - -func (conn *UDPIPv6) ipv6State(t *testing.T) *ipv6State { - t.Helper() - - state, ok := conn.layerStates[1].(*ipv6State) - if !ok { - t.Fatalf("got network-layer state type=%T, expected ipv6State", conn.layerStates[1]) - } - return state -} - -// LocalAddr gets the local socket address of this connection. -func (conn *UDPIPv6) LocalAddr(t *testing.T, zoneID uint32) *unix.SockaddrInet6 { - t.Helper() - - sa := &unix.SockaddrInet6{ - Port: int(*conn.udpState(t).out.SrcPort), - // Local address is in perspective to the remote host, so it's scoped to the - // ID of the remote interface. - ZoneId: zoneID, - } - copy(sa.Addr[:], *conn.ipv6State(t).out.SrcAddr) - return sa -} - -// SrcPort returns the source port of this connection. -func (conn *UDPIPv6) SrcPort(t *testing.T) uint16 { - t.Helper() - - return *conn.udpState(t).out.SrcPort -} - -// Send sends a packet with reasonable defaults, potentially overriding the UDP -// layer and adding additionLayers. -func (conn *UDPIPv6) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { - t.Helper() - - conn.send(t, Layers{&udp}, additionalLayers...) -} - -// SendIPv6 sends a packet with reasonable defaults, potentially overriding the -// UDP and IPv6 headers and adding additionLayers. -func (conn *UDPIPv6) SendIPv6(t *testing.T, ip IPv6, udp UDP, additionalLayers ...Layer) { - t.Helper() - - conn.send(t, Layers{&ip, &udp}, additionalLayers...) -} - -// SendFrame sends a frame on the wire and updates the state of all layers. -func (conn *UDPIPv6) SendFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) { - conn.send(t, overrideLayers, additionalLayers...) -} - -// Expect expects a frame with the UDP layer matching the provided UDP within -// the timeout specified. If it doesn't arrive in time, an error is returned. -func (conn *UDPIPv6) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { - t.Helper() - - layer, err := conn.Connection.Expect(t, &udp, timeout) - if err != nil { - return nil, err - } - gotUDP, ok := layer.(*UDP) - if !ok { - t.Fatalf("expected %s to be UDP", layer) - } - return gotUDP, nil -} - -// ExpectData is a convenient method that expects a Layer and the Layer after -// it. If it doesn't arrive in time, it returns nil. -func (conn *UDPIPv6) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) { - t.Helper() - - expected := make([]Layer, len(conn.layerStates)) - expected[len(expected)-1] = &udp - if payload.length() != 0 { - expected = append(expected, &payload) - } - return conn.ExpectFrame(t, expected, timeout) -} - -// TCPIPv6 maintains the state for all the layers in a TCP/IPv6 connection. -type TCPIPv6 struct { - Connection -} - -// NewTCPIPv6 creates a new TCPIPv6 connection with reasonable defaults. -func (n *DUTTestNet) NewTCPIPv6(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv6 { - etherState, err := n.newEtherState(Ether{}, Ether{}) - if err != nil { - t.Fatalf("can't make etherState: %s", err) - } - ipv6State, err := n.newIPv6State(IPv6{}, IPv6{}) - if err != nil { - t.Fatalf("can't make ipv6State: %s", err) - } - tcpState, err := n.newTCPState(unix.AF_INET6, outgoingTCP, incomingTCP) - if err != nil { - t.Fatalf("can't make tcpState: %s", err) - } - injector, err := n.NewInjector(t) - if err != nil { - t.Fatalf("can't make injector: %s", err) - } - sniffer, err := n.NewSniffer(t) - if err != nil { - t.Fatalf("can't make sniffer: %s", err) - } - - return TCPIPv6{ - Connection: Connection{ - layerStates: []layerState{etherState, ipv6State, tcpState}, - injector: injector, - sniffer: sniffer, - }, - } -} - -// SrcPort returns the source port from the given Connection. -func (conn *TCPIPv6) SrcPort() uint16 { - state := conn.layerStates[2].(*tcpState) - return *state.out.SrcPort -} - -// ExpectData is a convenient method that expects a Layer and the Layer after -// it. If it doesn't arrive in time, it returns nil. -func (conn *TCPIPv6) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { - t.Helper() - - expected := make([]Layer, len(conn.layerStates)) - expected[len(expected)-1] = tcp - if payload != nil { - expected = append(expected, payload) - } - return conn.ExpectFrame(t, expected, timeout) -} diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go deleted file mode 100644 index aeae5e5d3..000000000 --- a/test/packetimpact/testbench/dut.go +++ /dev/null @@ -1,798 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testbench - -import ( - "context" - "encoding/binary" - "fmt" - "net" - "syscall" - "testing" - "time" - - pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" - - "golang.org/x/sys/unix" - "google.golang.org/grpc" - "google.golang.org/grpc/keepalive" -) - -// DUT communicates with the DUT to force it to make POSIX calls. -type DUT struct { - conn *grpc.ClientConn - posixServer POSIXClient - Net *DUTTestNet - Uname *DUTUname -} - -// NewDUT creates a new connection with the DUT over gRPC. -func NewDUT(t *testing.T) DUT { - t.Helper() - info := getDUTInfo() - dut := info.ConnectToDUT(t) - t.Cleanup(func() { - dut.TearDownConnection() - info.release() - }) - return dut -} - -// ConnectToDUT connects to DUT through gRPC. -func (info *DUTInfo) ConnectToDUT(t *testing.T) DUT { - t.Helper() - - n := info.Net - posixServerAddress := net.JoinHostPort(n.POSIXServerIP.String(), fmt.Sprintf("%d", n.POSIXServerPort)) - conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: RPCKeepalive})) - if err != nil { - t.Fatalf("failed to grpc.Dial(%s): %s", posixServerAddress, err) - } - posixServer := NewPOSIXClient(conn) - return DUT{ - conn: conn, - posixServer: posixServer, - Net: n, - Uname: info.Uname, - } -} - -// TearDownConnection closes the underlying connection. -func (dut *DUT) TearDownConnection() { - dut.conn.Close() -} - -func (dut *DUT) sockaddrToProto(t *testing.T, sa unix.Sockaddr) *pb.Sockaddr { - t.Helper() - - switch s := sa.(type) { - case *unix.SockaddrInet4: - return &pb.Sockaddr{ - Sockaddr: &pb.Sockaddr_In{ - In: &pb.SockaddrIn{ - Family: unix.AF_INET, - Port: uint32(s.Port), - Addr: s.Addr[:], - }, - }, - } - case *unix.SockaddrInet6: - return &pb.Sockaddr{ - Sockaddr: &pb.Sockaddr_In6{ - In6: &pb.SockaddrIn6{ - Family: unix.AF_INET6, - Port: uint32(s.Port), - Flowinfo: 0, - ScopeId: s.ZoneId, - Addr: s.Addr[:], - }, - }, - } - } - t.Fatalf("can't parse Sockaddr struct: %+v", sa) - return nil -} - -func (dut *DUT) protoToSockaddr(t *testing.T, sa *pb.Sockaddr) unix.Sockaddr { - t.Helper() - - switch s := sa.Sockaddr.(type) { - case *pb.Sockaddr_In: - ret := unix.SockaddrInet4{ - Port: int(s.In.GetPort()), - } - copy(ret.Addr[:], s.In.GetAddr()) - return &ret - case *pb.Sockaddr_In6: - ret := unix.SockaddrInet6{ - Port: int(s.In6.GetPort()), - ZoneId: s.In6.GetScopeId(), - } - copy(ret.Addr[:], s.In6.GetAddr()) - return &ret - } - t.Fatalf("can't parse Sockaddr proto: %#v", sa) - return nil -} - -// CreateBoundSocket makes a new socket on the DUT, with type typ and protocol -// proto, and bound to the IP address addr. Returns the new file descriptor and -// the port that was selected on the DUT. -func (dut *DUT) CreateBoundSocket(t *testing.T, typ, proto int32, addr net.IP) (int32, uint16) { - t.Helper() - - var fd int32 - if addr.To4() != nil { - fd = dut.Socket(t, unix.AF_INET, typ, proto) - sa := unix.SockaddrInet4{} - copy(sa.Addr[:], addr.To4()) - dut.Bind(t, fd, &sa) - } else if addr.To16() != nil { - fd = dut.Socket(t, unix.AF_INET6, typ, proto) - sa := unix.SockaddrInet6{} - copy(sa.Addr[:], addr.To16()) - sa.ZoneId = dut.Net.RemoteDevID - dut.Bind(t, fd, &sa) - } else { - t.Fatalf("invalid IP address: %s", addr) - } - sa := dut.GetSockName(t, fd) - var port int - switch s := sa.(type) { - case *unix.SockaddrInet4: - port = s.Port - case *unix.SockaddrInet6: - port = s.Port - default: - t.Fatalf("unknown sockaddr type from getsockname: %T", sa) - } - return fd, uint16(port) -} - -// CreateListener makes a new TCP connection. If it fails, the test ends. -func (dut *DUT) CreateListener(t *testing.T, typ, proto, backlog int32) (int32, uint16) { - t.Helper() - - fd, remotePort := dut.CreateBoundSocket(t, typ, proto, dut.Net.RemoteIPv4) - dut.Listen(t, fd, backlog) - return fd, remotePort -} - -// All the functions that make gRPC calls to the POSIX service are below, sorted -// alphabetically. - -// Accept calls accept on the DUT and causes a fatal test failure if it doesn't -// succeed. If more control over the timeout or error handling is needed, use -// AcceptWithErrno. -func (dut *DUT) Accept(t *testing.T, sockfd int32) (int32, unix.Sockaddr) { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - fd, sa, err := dut.AcceptWithErrno(ctx, t, sockfd) - if fd < 0 { - t.Fatalf("failed to accept: %s", err) - } - return fd, sa -} - -// AcceptWithErrno calls accept on the DUT. -func (dut *DUT) AcceptWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) { - t.Helper() - - req := &pb.AcceptRequest{ - Sockfd: sockfd, - } - resp, err := dut.posixServer.Accept(ctx, req) - if err != nil { - t.Fatalf("failed to call Accept: %s", err) - } - return resp.GetFd(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_()) -} - -// Bind calls bind on the DUT and causes a fatal test failure if it doesn't -// succeed. If more control over the timeout or error handling is -// needed, use BindWithErrno. -func (dut *DUT) Bind(t *testing.T, fd int32, sa unix.Sockaddr) { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.BindWithErrno(ctx, t, fd, sa) - if ret != 0 { - t.Fatalf("failed to bind socket: %s", err) - } -} - -// BindWithErrno calls bind on the DUT. -func (dut *DUT) BindWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) { - t.Helper() - - req := &pb.BindRequest{ - Sockfd: fd, - Addr: dut.sockaddrToProto(t, sa), - } - resp, err := dut.posixServer.Bind(ctx, req) - if err != nil { - t.Fatalf("failed to call Bind: %s", err) - } - return resp.GetRet(), syscall.Errno(resp.GetErrno_()) -} - -// Close calls close on the DUT and causes a fatal test failure if it doesn't -// succeed. If more control over the timeout or error handling is needed, use -// CloseWithErrno. -func (dut *DUT) Close(t *testing.T, fd int32) { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.CloseWithErrno(ctx, t, fd) - if ret != 0 { - t.Fatalf("failed to close: %s", err) - } -} - -// CloseWithErrno calls close on the DUT. -func (dut *DUT) CloseWithErrno(ctx context.Context, t *testing.T, fd int32) (int32, error) { - t.Helper() - - req := &pb.CloseRequest{ - Fd: fd, - } - resp, err := dut.posixServer.Close(ctx, req) - if err != nil { - t.Fatalf("failed to call Close: %s", err) - } - return resp.GetRet(), syscall.Errno(resp.GetErrno_()) -} - -// Connect calls connect on the DUT and causes a fatal test failure if it -// doesn't succeed. If more control over the timeout or error handling is -// needed, use ConnectWithErrno. -func (dut *DUT) Connect(t *testing.T, fd int32, sa unix.Sockaddr) { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.ConnectWithErrno(ctx, t, fd, sa) - // Ignore 'operation in progress' error that can be returned when the socket - // is non-blocking. - if err != syscall.Errno(unix.EINPROGRESS) && ret != 0 { - t.Fatalf("failed to connect socket: %s", err) - } -} - -// ConnectWithErrno calls bind on the DUT. -func (dut *DUT) ConnectWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) { - t.Helper() - - req := &pb.ConnectRequest{ - Sockfd: fd, - Addr: dut.sockaddrToProto(t, sa), - } - resp, err := dut.posixServer.Connect(ctx, req) - if err != nil { - t.Fatalf("failed to call Connect: %s", err) - } - return resp.GetRet(), syscall.Errno(resp.GetErrno_()) -} - -// GetSockName calls getsockname on the DUT and causes a fatal test failure if -// it doesn't succeed. If more control over the timeout or error handling is -// needed, use GetSockNameWithErrno. -func (dut *DUT) GetSockName(t *testing.T, sockfd int32) unix.Sockaddr { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, sa, err := dut.GetSockNameWithErrno(ctx, t, sockfd) - if ret != 0 { - t.Fatalf("failed to getsockname: %s", err) - } - return sa -} - -// GetSockNameWithErrno calls getsockname on the DUT. -func (dut *DUT) GetSockNameWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) { - t.Helper() - - req := &pb.GetSockNameRequest{ - Sockfd: sockfd, - } - resp, err := dut.posixServer.GetSockName(ctx, req) - if err != nil { - t.Fatalf("failed to call Bind: %s", err) - } - return resp.GetRet(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_()) -} - -func (dut *DUT) getSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) { - t.Helper() - - req := &pb.GetSockOptRequest{ - Sockfd: sockfd, - Level: level, - Optname: optname, - Optlen: optlen, - Type: typ, - } - resp, err := dut.posixServer.GetSockOpt(ctx, req) - if err != nil { - t.Fatalf("failed to call GetSockOpt: %s", err) - } - optval := resp.GetOptval() - if optval == nil { - t.Fatalf("GetSockOpt response does not contain a value") - } - return resp.GetRet(), optval, syscall.Errno(resp.GetErrno_()) -} - -// GetSockOpt calls getsockopt on the DUT and causes a fatal test failure if it -// doesn't succeed. If more control over the timeout or error handling is -// needed, use GetSockOptWithErrno. Because endianess and the width of values -// might differ between the testbench and DUT architectures, prefer to use a -// more specific GetSockOptXxx function. -func (dut *DUT) GetSockOpt(t *testing.T, sockfd, level, optname, optlen int32) []byte { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, optval, err := dut.GetSockOptWithErrno(ctx, t, sockfd, level, optname, optlen) - if ret != 0 { - t.Fatalf("failed to GetSockOpt: %s", err) - } - return optval -} - -// GetSockOptWithErrno calls getsockopt on the DUT. Because endianess and the -// width of values might differ between the testbench and DUT architectures, -// prefer to use a more specific GetSockOptXxxWithErrno function. -func (dut *DUT) GetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32) (int32, []byte, error) { - t.Helper() - - ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, optlen, pb.GetSockOptRequest_BYTES) - bytesval, ok := optval.Val.(*pb.SockOptVal_Bytesval) - if !ok { - t.Fatalf("GetSockOpt got value type: %T, want bytes", optval.Val) - } - return ret, bytesval.Bytesval, errno -} - -// GetSockOptInt calls getsockopt on the DUT and causes a fatal test failure -// if it doesn't succeed. If more control over the int optval or error handling -// is needed, use GetSockOptIntWithErrno. -func (dut *DUT) GetSockOptInt(t *testing.T, sockfd, level, optname int32) int32 { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, intval, err := dut.GetSockOptIntWithErrno(ctx, t, sockfd, level, optname) - if ret != 0 { - t.Fatalf("failed to GetSockOptInt: %s", err) - } - return intval -} - -// GetSockOptIntWithErrno calls getsockopt with an integer optval. -func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, int32, error) { - t.Helper() - - ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_INT) - intval, ok := optval.Val.(*pb.SockOptVal_Intval) - if !ok { - t.Fatalf("GetSockOpt got value type: %T, want int", optval.Val) - } - return ret, intval.Intval, errno -} - -// GetSockOptTimeval calls getsockopt on the DUT and causes a fatal test failure -// if it doesn't succeed. If more control over the timeout or error handling is -// needed, use GetSockOptTimevalWithErrno. -func (dut *DUT) GetSockOptTimeval(t *testing.T, sockfd, level, optname int32) unix.Timeval { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname) - if ret != 0 { - t.Fatalf("failed to GetSockOptTimeval: %s", err) - } - return timeval -} - -// GetSockOptTimevalWithErrno calls getsockopt and returns a timeval. -func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, unix.Timeval, error) { - t.Helper() - - ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_TIME) - tv, ok := optval.Val.(*pb.SockOptVal_Timeval) - if !ok { - t.Fatalf("GetSockOpt got value type: %T, want timeval", optval.Val) - } - timeval := unix.Timeval{ - Sec: tv.Timeval.Seconds, - Usec: tv.Timeval.Microseconds, - } - return ret, timeval, errno -} - -// Listen calls listen on the DUT and causes a fatal test failure if it doesn't -// succeed. If more control over the timeout or error handling is needed, use -// ListenWithErrno. -func (dut *DUT) Listen(t *testing.T, sockfd, backlog int32) { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.ListenWithErrno(ctx, t, sockfd, backlog) - if ret != 0 { - t.Fatalf("failed to listen: %s", err) - } -} - -// ListenWithErrno calls listen on the DUT. -func (dut *DUT) ListenWithErrno(ctx context.Context, t *testing.T, sockfd, backlog int32) (int32, error) { - t.Helper() - - req := &pb.ListenRequest{ - Sockfd: sockfd, - Backlog: backlog, - } - resp, err := dut.posixServer.Listen(ctx, req) - if err != nil { - t.Fatalf("failed to call Listen: %s", err) - } - return resp.GetRet(), syscall.Errno(resp.GetErrno_()) -} - -// PollOne calls poll on the DUT and asserts that the expected event must be -// signaled on the given fd within the given timeout. -func (dut *DUT) PollOne(t *testing.T, fd int32, events int16, timeout time.Duration) { - t.Helper() - - pfds := dut.Poll(t, []unix.PollFd{{Fd: fd, Events: events}}, timeout) - if n := len(pfds); n != 1 { - t.Fatalf("Poll returned %d ready file descriptors, expected 1", n) - } - if readyFd := pfds[0].Fd; readyFd != fd { - t.Fatalf("Poll returned an fd %d that was not requested (%d)", readyFd, fd) - } - if got, want := pfds[0].Revents, int16(events); got&want == 0 { - t.Fatalf("Poll returned no events in our interest, got: %#b, want: %#b", got, want) - } -} - -// Poll calls poll on the DUT and causes a fatal test failure if it doesn't -// succeed. If more control over error handling is needed, use PollWithErrno. -// Only pollfds with non-empty revents are returned, the only way to tie the -// response back to the original request is using the fd number. -func (dut *DUT) Poll(t *testing.T, pfds []unix.PollFd, timeout time.Duration) []unix.PollFd { - t.Helper() - - ctx := context.Background() - var cancel context.CancelFunc - if timeout >= 0 { - ctx, cancel = context.WithTimeout(ctx, timeout+RPCTimeout) - defer cancel() - } - ret, result, err := dut.PollWithErrno(ctx, t, pfds, timeout) - if ret < 0 { - t.Fatalf("failed to poll: %s", err) - } - return result -} - -// PollWithErrno calls poll on the DUT. -func (dut *DUT) PollWithErrno(ctx context.Context, t *testing.T, pfds []unix.PollFd, timeout time.Duration) (int32, []unix.PollFd, error) { - t.Helper() - - req := &pb.PollRequest{ - TimeoutMillis: int32(timeout.Milliseconds()), - } - for _, pfd := range pfds { - req.Pfds = append(req.Pfds, &pb.PollFd{ - Fd: pfd.Fd, - Events: uint32(pfd.Events), - }) - } - resp, err := dut.posixServer.Poll(ctx, req) - if err != nil { - t.Fatalf("failed to call Poll: %s", err) - } - if ret, npfds := resp.GetRet(), len(resp.GetPfds()); ret >= 0 && int(ret) != npfds { - t.Fatalf("nonsensical poll response: ret(%d) != len(pfds)(%d)", ret, npfds) - } - var result []unix.PollFd - for _, protoPfd := range resp.GetPfds() { - result = append(result, unix.PollFd{ - Fd: protoPfd.GetFd(), - Revents: int16(protoPfd.GetEvents()), - }) - } - return resp.GetRet(), result, syscall.Errno(resp.GetErrno_()) -} - -// Send calls send on the DUT and causes a fatal test failure if it doesn't -// succeed. If more control over the timeout or error handling is needed, use -// SendWithErrno. -func (dut *DUT) Send(t *testing.T, sockfd int32, buf []byte, flags int32) int32 { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.SendWithErrno(ctx, t, sockfd, buf, flags) - if ret == -1 { - t.Fatalf("failed to send: %s", err) - } - return ret -} - -// SendWithErrno calls send on the DUT. -func (dut *DUT) SendWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32) (int32, error) { - t.Helper() - - req := &pb.SendRequest{ - Sockfd: sockfd, - Buf: buf, - Flags: flags, - } - resp, err := dut.posixServer.Send(ctx, req) - if err != nil { - t.Fatalf("failed to call Send: %s", err) - } - return resp.GetRet(), syscall.Errno(resp.GetErrno_()) -} - -// SendTo calls sendto on the DUT and causes a fatal test failure if it doesn't -// succeed. If more control over the timeout or error handling is needed, use -// SendToWithErrno. -func (dut *DUT) SendTo(t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.SendToWithErrno(ctx, t, sockfd, buf, flags, destAddr) - if ret == -1 { - t.Fatalf("failed to sendto: %s", err) - } - return ret -} - -// SendToWithErrno calls sendto on the DUT. -func (dut *DUT) SendToWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) { - t.Helper() - - req := &pb.SendToRequest{ - Sockfd: sockfd, - Buf: buf, - Flags: flags, - DestAddr: dut.sockaddrToProto(t, destAddr), - } - resp, err := dut.posixServer.SendTo(ctx, req) - if err != nil { - t.Fatalf("failed to call SendTo: %s", err) - } - return resp.GetRet(), syscall.Errno(resp.GetErrno_()) -} - -// SetNonBlocking will set O_NONBLOCK flag for fd if nonblocking -// is true, otherwise it will clear the flag. -func (dut *DUT) SetNonBlocking(t *testing.T, fd int32, nonblocking bool) { - t.Helper() - - req := &pb.SetNonblockingRequest{ - Fd: fd, - Nonblocking: nonblocking, - } - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - - resp, err := dut.posixServer.SetNonblocking(ctx, req) - if err != nil { - t.Fatalf("failed to call SetNonblocking: %s", err) - } - if resp.GetRet() == -1 { - t.Fatalf("fcntl(%d, %s) failed: %s", fd, resp.GetCmd(), syscall.Errno(resp.GetErrno_())) - } -} - -func (dut *DUT) setSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) { - t.Helper() - - req := &pb.SetSockOptRequest{ - Sockfd: sockfd, - Level: level, - Optname: optname, - Optval: optval, - } - resp, err := dut.posixServer.SetSockOpt(ctx, req) - if err != nil { - t.Fatalf("failed to call SetSockOpt: %s", err) - } - return resp.GetRet(), syscall.Errno(resp.GetErrno_()) -} - -// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it -// doesn't succeed. If more control over the timeout or error handling is -// needed, use SetSockOptWithErrno. Because endianess and the width of values -// might differ between the testbench and DUT architectures, prefer to use a -// more specific SetSockOptXxx function. -func (dut *DUT) SetSockOpt(t *testing.T, sockfd, level, optname int32, optval []byte) { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.SetSockOptWithErrno(ctx, t, sockfd, level, optname, optval) - if ret != 0 { - t.Fatalf("failed to SetSockOpt: %s", err) - } -} - -// SetSockOptWithErrno calls setsockopt on the DUT. Because endianess and the -// width of values might differ between the testbench and DUT architectures, -// prefer to use a more specific SetSockOptXxxWithErrno function. -func (dut *DUT) SetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval []byte) (int32, error) { - t.Helper() - - return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Bytesval{optval}}) -} - -// SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure -// if it doesn't succeed. If more control over the int optval or error handling -// is needed, use SetSockOptIntWithErrno. -func (dut *DUT) SetSockOptInt(t *testing.T, sockfd, level, optname, optval int32) { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.SetSockOptIntWithErrno(ctx, t, sockfd, level, optname, optval) - if ret != 0 { - t.Fatalf("failed to SetSockOptInt: %s", err) - } -} - -// SetSockOptIntWithErrno calls setsockopt with an integer optval. -func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optval int32) (int32, error) { - t.Helper() - - return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Intval{optval}}) -} - -// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure -// if it doesn't succeed. If more control over the timeout or error handling is -// needed, use SetSockOptTimevalWithErrno. -func (dut *DUT) SetSockOptTimeval(t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.SetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname, tv) - if ret != 0 { - t.Fatalf("failed to SetSockOptTimeval: %s", err) - } -} - -// SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to -// bytes. -func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) { - t.Helper() - - timeval := pb.Timeval{ - Seconds: int64(tv.Sec), - Microseconds: int64(tv.Usec), - } - return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Timeval{&timeval}}) -} - -// Socket calls socket on the DUT and returns the file descriptor. If socket -// fails on the DUT, the test ends. -func (dut *DUT) Socket(t *testing.T, domain, typ, proto int32) int32 { - t.Helper() - - fd, err := dut.SocketWithErrno(t, domain, typ, proto) - if fd < 0 { - t.Fatalf("failed to create socket: %s", err) - } - return fd -} - -// SocketWithErrno calls socket on the DUT and returns the fd and errno. -func (dut *DUT) SocketWithErrno(t *testing.T, domain, typ, proto int32) (int32, error) { - t.Helper() - - req := &pb.SocketRequest{ - Domain: domain, - Type: typ, - Protocol: proto, - } - ctx := context.Background() - resp, err := dut.posixServer.Socket(ctx, req) - if err != nil { - t.Fatalf("failed to call Socket: %s", err) - } - return resp.GetFd(), syscall.Errno(resp.GetErrno_()) -} - -// Recv calls recv on the DUT and causes a fatal test failure if it doesn't -// succeed. If more control over the timeout or error handling is needed, use -// RecvWithErrno. -func (dut *DUT) Recv(t *testing.T, sockfd, len, flags int32) []byte { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, buf, err := dut.RecvWithErrno(ctx, t, sockfd, len, flags) - if ret == -1 { - t.Fatalf("failed to recv: %s", err) - } - return buf -} - -// RecvWithErrno calls recv on the DUT. -func (dut *DUT) RecvWithErrno(ctx context.Context, t *testing.T, sockfd, len, flags int32) (int32, []byte, error) { - t.Helper() - - req := &pb.RecvRequest{ - Sockfd: sockfd, - Len: len, - Flags: flags, - } - resp, err := dut.posixServer.Recv(ctx, req) - if err != nil { - t.Fatalf("failed to call Recv: %s", err) - } - return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_()) -} - -// SetSockLingerOption sets SO_LINGER socket option on the DUT. -func (dut *DUT) SetSockLingerOption(t *testing.T, sockfd int32, timeout time.Duration, enable bool) { - var linger unix.Linger - if enable { - linger.Onoff = 1 - } - linger.Linger = int32(timeout / time.Second) - - buf := make([]byte, 8) - binary.LittleEndian.PutUint32(buf, uint32(linger.Onoff)) - binary.LittleEndian.PutUint32(buf[4:], uint32(linger.Linger)) - dut.SetSockOpt(t, sockfd, unix.SOL_SOCKET, unix.SO_LINGER, buf) -} - -// Shutdown calls shutdown on the DUT and causes a fatal test failure if it -// doesn't succeed. If more control over the timeout or error handling is -// needed, use ShutdownWithErrno. -func (dut *DUT) Shutdown(t *testing.T, fd, how int32) error { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - return dut.ShutdownWithErrno(ctx, t, fd, how) -} - -// ShutdownWithErrno calls shutdown on the DUT. -func (dut *DUT) ShutdownWithErrno(ctx context.Context, t *testing.T, fd, how int32) error { - t.Helper() - - req := &pb.ShutdownRequest{ - Fd: fd, - How: how, - } - resp, err := dut.posixServer.Shutdown(ctx, req) - if err != nil { - t.Fatalf("failed to call Shutdown: %s", err) - } - return syscall.Errno(resp.GetErrno_()) -} diff --git a/test/packetimpact/testbench/dut_client.go b/test/packetimpact/testbench/dut_client.go deleted file mode 100644 index 0fc3d97b4..000000000 --- a/test/packetimpact/testbench/dut_client.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testbench - -import ( - "google.golang.org/grpc" - pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" -) - -// POSIXClient is a gRPC client for the Posix service. -type POSIXClient pb.PosixClient - -// NewPOSIXClient makes a new gRPC client for the POSIX service. -func NewPOSIXClient(c grpc.ClientConnInterface) POSIXClient { - return pb.NewPosixClient(c) -} diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go deleted file mode 100644 index 64a7a171a..000000000 --- a/test/packetimpact/testbench/layers.go +++ /dev/null @@ -1,1595 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testbench - -import ( - "encoding/binary" - "encoding/hex" - "fmt" - "reflect" - "strings" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "go.uber.org/multierr" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -// Layer is the interface that all encapsulations must implement. -// -// A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A -// Layer contains all the fields of the encapsulation. Each field is a pointer -// and may be nil. -type Layer interface { - fmt.Stringer - - // ToBytes converts the Layer into bytes. In places where the Layer's field - // isn't nil, the value that is pointed to is used. When the field is nil, a - // reasonable default for the Layer is used. For example, "64" for IPv4 TTL - // and a calculated checksum for TCP or IP. Some layers require information - // from the previous or next layers in order to compute a default, such as - // TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list - // to the layer's neighbors. - ToBytes() ([]byte, error) - - // match checks if the current Layer matches the provided Layer. If either - // Layer has a nil in a given field, that field is considered matching. - // Otherwise, the values pointed to by the fields must match. The LayerBase is - // ignored. - match(Layer) bool - - // length in bytes of the current encapsulation - length() int - - // next gets a pointer to the encapsulated Layer. - next() Layer - - // prev gets a pointer to the Layer encapsulating this one. - Prev() Layer - - // setNext sets the pointer to the encapsulated Layer. - setNext(Layer) - - // setPrev sets the pointer to the Layer encapsulating this one. - setPrev(Layer) - - // merge overrides the values in the interface with the provided values. - merge(Layer) error -} - -// LayerBase is the common elements of all layers. -type LayerBase struct { - nextLayer Layer - prevLayer Layer -} - -func (lb *LayerBase) next() Layer { - return lb.nextLayer -} - -// Prev returns the previous layer. -func (lb *LayerBase) Prev() Layer { - return lb.prevLayer -} - -func (lb *LayerBase) setNext(l Layer) { - lb.nextLayer = l -} - -func (lb *LayerBase) setPrev(l Layer) { - lb.prevLayer = l -} - -// equalLayer compares that two Layer structs match while ignoring field in -// which either input has a nil and also ignoring the LayerBase of the inputs. -func equalLayer(x, y Layer) bool { - if x == nil || y == nil { - return true - } - // opt ignores comparison pairs where either of the inputs is a nil. - opt := cmp.FilterValues(func(x, y interface{}) bool { - for _, l := range []interface{}{x, y} { - v := reflect.ValueOf(l) - if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice) && v.IsNil() { - return true - } - } - return false - }, cmp.Ignore()) - return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{})) -} - -// mergeLayer merges y into x. Any fields for which y has a non-nil value, that -// value overwrite the corresponding fields in x. -func mergeLayer(x, y Layer) error { - if y == nil { - return nil - } - if reflect.TypeOf(x) != reflect.TypeOf(y) { - return fmt.Errorf("can't merge %T into %T", y, x) - } - vx := reflect.ValueOf(x).Elem() - vy := reflect.ValueOf(y).Elem() - t := vy.Type() - for i := 0; i < vy.NumField(); i++ { - t := t.Field(i) - if t.Anonymous { - // Ignore the LayerBase in the Layer struct. - continue - } - v := vy.Field(i) - if v.IsNil() { - continue - } - vx.Field(i).Set(v) - } - return nil -} - -func stringLayer(l Layer) string { - v := reflect.ValueOf(l).Elem() - t := v.Type() - var ret []string - for i := 0; i < v.NumField(); i++ { - t := t.Field(i) - if t.Anonymous { - // Ignore the LayerBase in the Layer struct. - continue - } - v := v.Field(i) - if v.IsNil() { - continue - } - v = reflect.Indirect(v) - if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 { - ret = append(ret, fmt.Sprintf("%s:\n%v", t.Name, hex.Dump(v.Bytes()))) - } else { - ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v)) - } - } - return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " ")) -} - -// Ether can construct and match an ethernet encapsulation. -type Ether struct { - LayerBase - SrcAddr *tcpip.LinkAddress - DstAddr *tcpip.LinkAddress - Type *tcpip.NetworkProtocolNumber -} - -func (l *Ether) String() string { - return stringLayer(l) -} - -// ToBytes implements Layer.ToBytes. -func (l *Ether) ToBytes() ([]byte, error) { - b := make([]byte, header.EthernetMinimumSize) - h := header.Ethernet(b) - fields := &header.EthernetFields{} - if l.SrcAddr != nil { - fields.SrcAddr = *l.SrcAddr - } - if l.DstAddr != nil { - fields.DstAddr = *l.DstAddr - } - if l.Type != nil { - fields.Type = *l.Type - } else { - switch n := l.next().(type) { - case *IPv4: - fields.Type = header.IPv4ProtocolNumber - case *IPv6: - fields.Type = header.IPv6ProtocolNumber - default: - return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n) - } - } - h.Encode(fields) - return h, nil -} - -// LinkAddress is a helper routine that allocates a new tcpip.LinkAddress value -// to store v and returns a pointer to it. -func LinkAddress(v tcpip.LinkAddress) *tcpip.LinkAddress { - return &v -} - -// NetworkProtocolNumber is a helper routine that allocates a new -// tcpip.NetworkProtocolNumber value to store v and returns a pointer to it. -func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocolNumber { - return &v -} - -// layerParser parses the input bytes and returns a Layer along with the next -// layerParser to run. If there is no more parsing to do, the returned -// layerParser is nil. -type layerParser func([]byte) (Layer, layerParser) - -// parse parses bytes starting with the first layerParser and using successive -// layerParsers until all the bytes are parsed. -func parse(parser layerParser, b []byte) Layers { - var layers Layers - for { - var layer Layer - layer, parser = parser(b) - layers = append(layers, layer) - if parser == nil { - break - } - b = b[layer.length():] - } - layers.linkLayers() - return layers -} - -// parseEther parses the bytes assuming that they start with an ethernet header -// and continues parsing further encapsulations. -func parseEther(b []byte) (Layer, layerParser) { - h := header.Ethernet(b) - ether := Ether{ - SrcAddr: LinkAddress(h.SourceAddress()), - DstAddr: LinkAddress(h.DestinationAddress()), - Type: NetworkProtocolNumber(h.Type()), - } - var nextParser layerParser - switch h.Type() { - case header.IPv4ProtocolNumber: - nextParser = parseIPv4 - case header.IPv6ProtocolNumber: - nextParser = parseIPv6 - default: - // Assume that the rest is a payload. - nextParser = parsePayload - } - return ðer, nextParser -} - -func (l *Ether) match(other Layer) bool { - return equalLayer(l, other) -} - -func (l *Ether) length() int { - return header.EthernetMinimumSize -} - -// merge implements Layer.merge. -func (l *Ether) merge(other Layer) error { - return mergeLayer(l, other) -} - -// IPv4 can construct and match an IPv4 encapsulation. -type IPv4 struct { - LayerBase - IHL *uint8 - TOS *uint8 - TotalLength *uint16 - ID *uint16 - Flags *uint8 - FragmentOffset *uint16 - TTL *uint8 - Protocol *uint8 - Checksum *uint16 - SrcAddr *tcpip.Address - DstAddr *tcpip.Address - Options *header.IPv4Options -} - -func (l *IPv4) String() string { - return stringLayer(l) -} - -// ToBytes implements Layer.ToBytes. -func (l *IPv4) ToBytes() ([]byte, error) { - // An IPv4 header is variable length depending on the size of the Options. - hdrLen := header.IPv4MinimumSize - if l.Options != nil { - if len(*l.Options)%4 != 0 { - return nil, fmt.Errorf("invalid header options '%x (len=%d)'; must be 32 bit aligned", *l.Options, len(*l.Options)) - } - hdrLen += len(*l.Options) - if hdrLen > header.IPv4MaximumHeaderSize { - return nil, fmt.Errorf("IPv4 Options %d bytes, Max %d", len(*l.Options), header.IPv4MaximumOptionsSize) - } - } - b := make([]byte, hdrLen) - h := header.IPv4(b) - fields := &header.IPv4Fields{ - TOS: 0, - TotalLength: 0, - ID: 0, - Flags: 0, - FragmentOffset: 0, - TTL: 64, - Protocol: 0, - Checksum: 0, - SrcAddr: tcpip.Address(""), - DstAddr: tcpip.Address(""), - Options: nil, - } - if l.TOS != nil { - fields.TOS = *l.TOS - } - if l.TotalLength != nil { - fields.TotalLength = *l.TotalLength - } else { - fields.TotalLength = uint16(l.length()) - current := l.next() - for current != nil { - fields.TotalLength += uint16(current.length()) - current = current.next() - } - } - if l.ID != nil { - fields.ID = *l.ID - } - if l.Flags != nil { - fields.Flags = *l.Flags - } - if l.FragmentOffset != nil { - fields.FragmentOffset = *l.FragmentOffset - } - if l.TTL != nil { - fields.TTL = *l.TTL - } - if l.Protocol != nil { - fields.Protocol = *l.Protocol - } else { - switch n := l.next().(type) { - case *TCP: - fields.Protocol = uint8(header.TCPProtocolNumber) - case *UDP: - fields.Protocol = uint8(header.UDPProtocolNumber) - case *ICMPv4: - fields.Protocol = uint8(header.ICMPv4ProtocolNumber) - default: - // TODO(b/150301488): Support more protocols as needed. - return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n) - } - } - if l.SrcAddr != nil { - fields.SrcAddr = *l.SrcAddr - } - if l.DstAddr != nil { - fields.DstAddr = *l.DstAddr - } - - h.Encode(fields) - - // Put raw option bytes from test definition in header. Options as raw bytes - // allows us to serialize malformed options, which is not possible with - // the provided serialization functions. - if l.Options != nil { - h.SetHeaderLength(h.HeaderLength() + uint8(len(*l.Options))) - if got, want := copy(h.Options(), *l.Options), len(*l.Options); got != want { - return nil, fmt.Errorf("failed to copy option bytes into header, got %d want %d", got, want) - } - } - - // Encode cannot set this incorrectly so we need to overwrite what it wrote - // in order to test handling of a bad IHL value. - if l.IHL != nil { - h.SetHeaderLength(*l.IHL) - } - - if l.Checksum == nil { - h.SetChecksum(^h.CalculateChecksum()) - } else { - h.SetChecksum(*l.Checksum) - } - - return h, nil -} - -// Uint16 is a helper routine that allocates a new -// uint16 value to store v and returns a pointer to it. -func Uint16(v uint16) *uint16 { - return &v -} - -// Uint8 is a helper routine that allocates a new -// uint8 value to store v and returns a pointer to it. -func Uint8(v uint8) *uint8 { - return &v -} - -// Address is a helper routine that allocates a new tcpip.Address value to -// store v and returns a pointer to it. -func Address(v tcpip.Address) *tcpip.Address { - return &v -} - -// parseIPv4 parses the bytes assuming that they start with an ipv4 header and -// continues parsing further encapsulations. -func parseIPv4(b []byte) (Layer, layerParser) { - h := header.IPv4(b) - options := h.Options() - tos, _ := h.TOS() - ipv4 := IPv4{ - IHL: Uint8(h.HeaderLength()), - TOS: &tos, - TotalLength: Uint16(h.TotalLength()), - ID: Uint16(h.ID()), - Flags: Uint8(h.Flags()), - FragmentOffset: Uint16(h.FragmentOffset()), - TTL: Uint8(h.TTL()), - Protocol: Uint8(h.Protocol()), - Checksum: Uint16(h.Checksum()), - SrcAddr: Address(h.SourceAddress()), - DstAddr: Address(h.DestinationAddress()), - Options: &options, - } - var nextParser layerParser - // If it is a fragment, don't treat it as having a transport protocol. - if h.FragmentOffset() != 0 || h.More() { - return &ipv4, parsePayload - } - switch h.TransportProtocol() { - case header.TCPProtocolNumber: - nextParser = parseTCP - case header.UDPProtocolNumber: - nextParser = parseUDP - case header.ICMPv4ProtocolNumber: - nextParser = parseICMPv4 - default: - // Assume that the rest is a payload. - nextParser = parsePayload - } - return &ipv4, nextParser -} - -func (l *IPv4) match(other Layer) bool { - return equalLayer(l, other) -} - -func (l *IPv4) length() int { - if l.IHL == nil { - return header.IPv4MinimumSize - } - return int(*l.IHL) -} - -// merge implements Layer.merge. -func (l *IPv4) merge(other Layer) error { - return mergeLayer(l, other) -} - -// IPv6 can construct and match an IPv6 encapsulation. -type IPv6 struct { - LayerBase - TrafficClass *uint8 - FlowLabel *uint32 - PayloadLength *uint16 - NextHeader *uint8 - HopLimit *uint8 - SrcAddr *tcpip.Address - DstAddr *tcpip.Address -} - -func (l *IPv6) String() string { - return stringLayer(l) -} - -// ToBytes implements Layer.ToBytes. -func (l *IPv6) ToBytes() ([]byte, error) { - b := make([]byte, header.IPv6MinimumSize) - h := header.IPv6(b) - fields := &header.IPv6Fields{ - HopLimit: 64, - } - if l.TrafficClass != nil { - fields.TrafficClass = *l.TrafficClass - } - if l.FlowLabel != nil { - fields.FlowLabel = *l.FlowLabel - } - if l.PayloadLength != nil { - fields.PayloadLength = *l.PayloadLength - } else { - for current := l.next(); current != nil; current = current.next() { - fields.PayloadLength += uint16(current.length()) - } - } - if l.NextHeader != nil { - fields.TransportProtocol = tcpip.TransportProtocolNumber(*l.NextHeader) - } else { - nh, err := nextHeaderByLayer(l.next()) - if err != nil { - return nil, err - } - fields.TransportProtocol = tcpip.TransportProtocolNumber(nh) - } - if l.HopLimit != nil { - fields.HopLimit = *l.HopLimit - } - if l.SrcAddr != nil { - fields.SrcAddr = *l.SrcAddr - } - if l.DstAddr != nil { - fields.DstAddr = *l.DstAddr - } - h.Encode(fields) - return h, nil -} - -// nextIPv6PayloadParser finds the corresponding parser for nextHeader. -func nextIPv6PayloadParser(nextHeader uint8) layerParser { - switch tcpip.TransportProtocolNumber(nextHeader) { - case header.TCPProtocolNumber: - return parseTCP - case header.UDPProtocolNumber: - return parseUDP - case header.ICMPv6ProtocolNumber: - return parseICMPv6 - } - switch header.IPv6ExtensionHeaderIdentifier(nextHeader) { - case header.IPv6HopByHopOptionsExtHdrIdentifier: - return parseIPv6HopByHopOptionsExtHdr - case header.IPv6DestinationOptionsExtHdrIdentifier: - return parseIPv6DestinationOptionsExtHdr - case header.IPv6FragmentExtHdrIdentifier: - return parseIPv6FragmentExtHdr - } - return parsePayload -} - -// parseIPv6 parses the bytes assuming that they start with an ipv6 header and -// continues parsing further encapsulations. -func parseIPv6(b []byte) (Layer, layerParser) { - h := header.IPv6(b) - tos, flowLabel := h.TOS() - ipv6 := IPv6{ - TrafficClass: &tos, - FlowLabel: &flowLabel, - PayloadLength: Uint16(h.PayloadLength()), - NextHeader: Uint8(h.NextHeader()), - HopLimit: Uint8(h.HopLimit()), - SrcAddr: Address(h.SourceAddress()), - DstAddr: Address(h.DestinationAddress()), - } - nextParser := nextIPv6PayloadParser(h.NextHeader()) - return &ipv6, nextParser -} - -func (l *IPv6) match(other Layer) bool { - return equalLayer(l, other) -} - -func (l *IPv6) length() int { - return header.IPv6MinimumSize -} - -// merge overrides the values in l with the values from other but only in fields -// where the value is not nil. -func (l *IPv6) merge(other Layer) error { - return mergeLayer(l, other) -} - -// IPv6HopByHopOptionsExtHdr can construct and match an IPv6HopByHopOptions -// Extension Header. -type IPv6HopByHopOptionsExtHdr struct { - LayerBase - NextHeader *header.IPv6ExtensionHeaderIdentifier - Options []byte -} - -// IPv6DestinationOptionsExtHdr can construct and match an IPv6DestinationOptions -// Extension Header. -type IPv6DestinationOptionsExtHdr struct { - LayerBase - NextHeader *header.IPv6ExtensionHeaderIdentifier - Options []byte -} - -// IPv6FragmentExtHdr can construct and match an IPv6 Fragment Extension Header. -type IPv6FragmentExtHdr struct { - LayerBase - NextHeader *header.IPv6ExtensionHeaderIdentifier - FragmentOffset *uint16 - MoreFragments *bool - Identification *uint32 -} - -// nextHeaderByLayer finds the correct next header protocol value for layer l. -func nextHeaderByLayer(l Layer) (uint8, error) { - if l == nil { - return uint8(header.IPv6NoNextHeaderIdentifier), nil - } - switch l.(type) { - case *TCP: - return uint8(header.TCPProtocolNumber), nil - case *UDP: - return uint8(header.UDPProtocolNumber), nil - case *ICMPv6: - return uint8(header.ICMPv6ProtocolNumber), nil - case *Payload: - return uint8(header.IPv6NoNextHeaderIdentifier), nil - case *IPv6HopByHopOptionsExtHdr: - return uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), nil - case *IPv6DestinationOptionsExtHdr: - return uint8(header.IPv6DestinationOptionsExtHdrIdentifier), nil - case *IPv6FragmentExtHdr: - return uint8(header.IPv6FragmentExtHdrIdentifier), nil - default: - // TODO(b/161005083): Support more protocols as needed. - return 0, fmt.Errorf("failed to deduce the IPv6 header's next protocol: %T", l) - } -} - -// ipv6OptionsExtHdrToBytes serializes an options extension header into bytes. -func ipv6OptionsExtHdrToBytes(nextHeader *header.IPv6ExtensionHeaderIdentifier, nextLayer Layer, options []byte) ([]byte, error) { - length := len(options) + 2 - if length%8 != 0 { - return nil, fmt.Errorf("IPv6 extension headers must be a multiple of 8 octets long, but the length given: %d, options: %s", length, hex.Dump(options)) - } - bytes := make([]byte, length) - if nextHeader != nil { - bytes[0] = byte(*nextHeader) - } else { - nh, err := nextHeaderByLayer(nextLayer) - if err != nil { - return nil, err - } - bytes[0] = nh - } - // ExtHdrLen field is the length of the extension header - // in 8-octet unit, ignoring the first 8 octets. - // https://tools.ietf.org/html/rfc2460#section-4.3 - // https://tools.ietf.org/html/rfc2460#section-4.6 - bytes[1] = uint8((length - 8) / 8) - copy(bytes[2:], options) - return bytes, nil -} - -// IPv6ExtHdrIdent is a helper routine that allocates a new -// header.IPv6ExtensionHeaderIdentifier value to store v and returns a pointer -// to it. -func IPv6ExtHdrIdent(id header.IPv6ExtensionHeaderIdentifier) *header.IPv6ExtensionHeaderIdentifier { - return &id -} - -// ToBytes implements Layer.ToBytes. -func (l *IPv6HopByHopOptionsExtHdr) ToBytes() ([]byte, error) { - return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options) -} - -// ToBytes implements Layer.ToBytes. -func (l *IPv6DestinationOptionsExtHdr) ToBytes() ([]byte, error) { - return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options) -} - -// ToBytes implements Layer.ToBytes. -func (l *IPv6FragmentExtHdr) ToBytes() ([]byte, error) { - var offset, mflag uint16 - var ident uint32 - bytes := make([]byte, header.IPv6FragmentExtHdrLength) - if l.NextHeader != nil { - bytes[0] = byte(*l.NextHeader) - } else { - nh, err := nextHeaderByLayer(l.next()) - if err != nil { - return nil, err - } - bytes[0] = nh - } - bytes[1] = 0 // reserved - if l.MoreFragments != nil && *l.MoreFragments { - mflag = 1 - } - if l.FragmentOffset != nil { - offset = *l.FragmentOffset - } - if l.Identification != nil { - ident = *l.Identification - } - offsetAndMflag := offset<<3 | mflag - binary.BigEndian.PutUint16(bytes[2:], offsetAndMflag) - binary.BigEndian.PutUint32(bytes[4:], ident) - - return bytes, nil -} - -// parseIPv6ExtHdr parses an IPv6 extension header and returns the NextHeader -// field, the rest of the payload and a parser function for the corresponding -// next extension header. -func parseIPv6ExtHdr(b []byte) (header.IPv6ExtensionHeaderIdentifier, []byte, layerParser) { - nextHeader := b[0] - // For HopByHop and Destination options extension headers, - // This field is the length of the extension header in - // 8-octet units, not including the first 8 octets. - // https://tools.ietf.org/html/rfc2460#section-4.3 - // https://tools.ietf.org/html/rfc2460#section-4.6 - length := b[1]*8 + 8 - data := b[2:length] - nextParser := nextIPv6PayloadParser(nextHeader) - return header.IPv6ExtensionHeaderIdentifier(nextHeader), data, nextParser -} - -// parseIPv6HopByHopOptionsExtHdr parses the bytes assuming that they start -// with an IPv6 HopByHop Options Extension Header. -func parseIPv6HopByHopOptionsExtHdr(b []byte) (Layer, layerParser) { - nextHeader, options, nextParser := parseIPv6ExtHdr(b) - return &IPv6HopByHopOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser -} - -// parseIPv6DestinationOptionsExtHdr parses the bytes assuming that they start -// with an IPv6 Destination Options Extension Header. -func parseIPv6DestinationOptionsExtHdr(b []byte) (Layer, layerParser) { - nextHeader, options, nextParser := parseIPv6ExtHdr(b) - return &IPv6DestinationOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser -} - -// Bool is a helper routine that allocates a new -// bool value to store v and returns a pointer to it. -func Bool(v bool) *bool { - return &v -} - -// parseIPv6FragmentExtHdr parses the bytes assuming that they start -// with an IPv6 Fragment Extension Header. -func parseIPv6FragmentExtHdr(b []byte) (Layer, layerParser) { - nextHeader := b[0] - var extHdr header.IPv6FragmentExtHdr - copy(extHdr[:], b[2:]) - fragLayer := IPv6FragmentExtHdr{ - NextHeader: IPv6ExtHdrIdent(header.IPv6ExtensionHeaderIdentifier(nextHeader)), - FragmentOffset: Uint16(extHdr.FragmentOffset()), - MoreFragments: Bool(extHdr.More()), - Identification: Uint32(extHdr.ID()), - } - // If it is a fragment, we can't interpret it. - if extHdr.FragmentOffset() != 0 || extHdr.More() { - return &fragLayer, parsePayload - } - return &fragLayer, nextIPv6PayloadParser(nextHeader) -} - -func (l *IPv6HopByHopOptionsExtHdr) length() int { - return len(l.Options) + 2 -} - -func (l *IPv6HopByHopOptionsExtHdr) match(other Layer) bool { - return equalLayer(l, other) -} - -// merge overrides the values in l with the values from other but only in fields -// where the value is not nil. -func (l *IPv6HopByHopOptionsExtHdr) merge(other Layer) error { - return mergeLayer(l, other) -} - -func (l *IPv6HopByHopOptionsExtHdr) String() string { - return stringLayer(l) -} - -func (l *IPv6DestinationOptionsExtHdr) length() int { - return len(l.Options) + 2 -} - -func (l *IPv6DestinationOptionsExtHdr) match(other Layer) bool { - return equalLayer(l, other) -} - -// merge overrides the values in l with the values from other but only in fields -// where the value is not nil. -func (l *IPv6DestinationOptionsExtHdr) merge(other Layer) error { - return mergeLayer(l, other) -} - -func (l *IPv6DestinationOptionsExtHdr) String() string { - return stringLayer(l) -} - -func (*IPv6FragmentExtHdr) length() int { - return header.IPv6FragmentExtHdrLength -} - -func (l *IPv6FragmentExtHdr) match(other Layer) bool { - return equalLayer(l, other) -} - -// merge overrides the values in l with the values from other but only in fields -// where the value is not nil. -func (l *IPv6FragmentExtHdr) merge(other Layer) error { - return mergeLayer(l, other) -} - -func (l *IPv6FragmentExtHdr) String() string { - return stringLayer(l) -} - -// ICMPv6 can construct and match an ICMPv6 encapsulation. -type ICMPv6 struct { - LayerBase - Type *header.ICMPv6Type - Code *header.ICMPv6Code - Checksum *uint16 - Payload []byte -} - -func (l *ICMPv6) String() string { - // TODO(eyalsoha): Do something smarter here when *l.Type is ParameterProblem? - // We could parse the contents of the Payload as if it were an IPv6 packet. - return stringLayer(l) -} - -// ToBytes implements Layer.ToBytes. -func (l *ICMPv6) ToBytes() ([]byte, error) { - b := make([]byte, header.ICMPv6HeaderSize+len(l.Payload)) - h := header.ICMPv6(b) - if l.Type != nil { - h.SetType(*l.Type) - } - if l.Code != nil { - h.SetCode(*l.Code) - } - if n := copy(h.MessageBody(), l.Payload); n != len(l.Payload) { - panic(fmt.Sprintf("copied %d bytes, expected to copy %d bytes", n, len(l.Payload))) - } - if l.Checksum != nil { - h.SetChecksum(*l.Checksum) - } else { - // It is possible that the ICMPv6 header does not follow the IPv6 header - // immediately, there could be one or more extension headers in between. - // We need to search forward to find the IPv6 header. - for prev := l.Prev(); prev != nil; prev = prev.Prev() { - if ipv6, ok := prev.(*IPv6); ok { - payload, err := payload(l) - if err != nil { - return nil, err - } - h.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: h, - Src: *ipv6.SrcAddr, - Dst: *ipv6.DstAddr, - PayloadCsum: header.ChecksumVV(payload, 0 /* initial */), - PayloadLen: payload.Size(), - })) - break - } - } - } - return h, nil -} - -// ICMPv6Type is a helper routine that allocates a new ICMPv6Type value to store -// v and returns a pointer to it. -func ICMPv6Type(v header.ICMPv6Type) *header.ICMPv6Type { - return &v -} - -// ICMPv6Code is a helper routine that allocates a new ICMPv6Type value to store -// v and returns a pointer to it. -func ICMPv6Code(v header.ICMPv6Code) *header.ICMPv6Code { - return &v -} - -// Byte is a helper routine that allocates a new byte value to store -// v and returns a pointer to it. -func Byte(v byte) *byte { - return &v -} - -// parseICMPv6 parses the bytes assuming that they start with an ICMPv6 header. -func parseICMPv6(b []byte) (Layer, layerParser) { - h := header.ICMPv6(b) - icmpv6 := ICMPv6{ - Type: ICMPv6Type(h.Type()), - Code: ICMPv6Code(h.Code()), - Checksum: Uint16(h.Checksum()), - Payload: h.MessageBody(), - } - return &icmpv6, nil -} - -func (l *ICMPv6) match(other Layer) bool { - return equalLayer(l, other) -} - -func (l *ICMPv6) length() int { - return header.ICMPv6HeaderSize + len(l.Payload) -} - -// merge overrides the values in l with the values from other but only in fields -// where the value is not nil. -func (l *ICMPv6) merge(other Layer) error { - return mergeLayer(l, other) -} - -// ICMPv4 can construct and match an ICMPv4 encapsulation. -type ICMPv4 struct { - LayerBase - Type *header.ICMPv4Type - Code *header.ICMPv4Code - Checksum *uint16 - Ident *uint16 // Only in Echo Request/Reply. - Sequence *uint16 // Only in Echo Request/Reply. - Pointer *uint8 // Only in Parameter Problem. - Payload []byte -} - -func (l *ICMPv4) String() string { - return stringLayer(l) -} - -// ICMPv4Type is a helper routine that allocates a new header.ICMPv4Type value -// to store t and returns a pointer to it. -func ICMPv4Type(t header.ICMPv4Type) *header.ICMPv4Type { - return &t -} - -// ICMPv4Code is a helper routine that allocates a new header.ICMPv4Code value -// to store t and returns a pointer to it. -func ICMPv4Code(t header.ICMPv4Code) *header.ICMPv4Code { - return &t -} - -// ToBytes implements Layer.ToBytes. -func (l *ICMPv4) ToBytes() ([]byte, error) { - b := make([]byte, header.ICMPv4MinimumSize+len(l.Payload)) - h := header.ICMPv4(b) - if l.Type != nil { - h.SetType(*l.Type) - } - if l.Code != nil { - h.SetCode(*l.Code) - } - if copied := copy(h.Payload(), l.Payload); copied != len(l.Payload) { - panic(fmt.Sprintf("wrong number of bytes copied into h.Payload(): got = %d, want = %d", len(h.Payload()), len(l.Payload))) - } - typ := h.Type() - switch typ { - case header.ICMPv4EchoReply, header.ICMPv4Echo: - if l.Ident != nil { - h.SetIdent(*l.Ident) - } - if l.Sequence != nil { - h.SetSequence(*l.Sequence) - } - case header.ICMPv4ParamProblem: - if l.Pointer != nil { - h.SetPointer(*l.Pointer) - } - } - - // The checksum must be handled last because the ICMPv4 header fields are - // included in the computation. - if l.Checksum != nil { - h.SetChecksum(*l.Checksum) - } else { - // Compute the checksum based on the ICMPv4.Payload and also the subsequent - // layers. - payload, err := payload(l) - if err != nil { - return nil, err - } - var vv buffer.VectorisedView - vv.AppendView(buffer.View(l.Payload)) - vv.Append(payload) - h.SetChecksum(header.ICMPv4Checksum(h, header.ChecksumVV(vv, 0 /* initial */))) - } - - return h, nil -} - -// parseICMPv4 parses the bytes as an ICMPv4 header, returning a Layer and a -// parser for the encapsulated payload. -func parseICMPv4(b []byte) (Layer, layerParser) { - h := header.ICMPv4(b) - - msgType := h.Type() - icmpv4 := ICMPv4{ - Type: ICMPv4Type(msgType), - Code: ICMPv4Code(h.Code()), - Checksum: Uint16(h.Checksum()), - Payload: h.Payload(), - } - switch msgType { - case header.ICMPv4EchoReply, header.ICMPv4Echo: - icmpv4.Ident = Uint16(h.Ident()) - icmpv4.Sequence = Uint16(h.Sequence()) - case header.ICMPv4ParamProblem: - icmpv4.Pointer = Uint8(h.Pointer()) - } - return &icmpv4, nil -} - -func (l *ICMPv4) match(other Layer) bool { - return equalLayer(l, other) -} - -func (l *ICMPv4) length() int { - return header.ICMPv4MinimumSize -} - -// merge overrides the values in l with the values from other but only in fields -// where the value is not nil. -func (l *ICMPv4) merge(other Layer) error { - return mergeLayer(l, other) -} - -// TCP can construct and match a TCP encapsulation. -type TCP struct { - LayerBase - SrcPort *uint16 - DstPort *uint16 - SeqNum *uint32 - AckNum *uint32 - DataOffset *uint8 - Flags *uint8 - WindowSize *uint16 - Checksum *uint16 - UrgentPointer *uint16 - Options []byte -} - -func (l *TCP) String() string { - return stringLayer(l) -} - -// ToBytes implements Layer.ToBytes. -func (l *TCP) ToBytes() ([]byte, error) { - b := make([]byte, l.length()) - h := header.TCP(b) - if l.SrcPort != nil { - h.SetSourcePort(*l.SrcPort) - } - if l.DstPort != nil { - h.SetDestinationPort(*l.DstPort) - } - if l.SeqNum != nil { - h.SetSequenceNumber(*l.SeqNum) - } - if l.AckNum != nil { - h.SetAckNumber(*l.AckNum) - } - if l.DataOffset != nil { - h.SetDataOffset(*l.DataOffset) - } else { - h.SetDataOffset(uint8(l.length())) - } - if l.Flags != nil { - h.SetFlags(*l.Flags) - } - if l.WindowSize != nil { - h.SetWindowSize(*l.WindowSize) - } else { - h.SetWindowSize(32768) - } - if l.UrgentPointer != nil { - h.SetUrgentPoiner(*l.UrgentPointer) - } - copy(b[header.TCPMinimumSize:], l.Options) - header.AddTCPOptionPadding(b[header.TCPMinimumSize:], len(l.Options)) - if l.Checksum != nil { - h.SetChecksum(*l.Checksum) - return h, nil - } - if err := setTCPChecksum(&h, l); err != nil { - return nil, err - } - return h, nil -} - -// totalLength returns the length of the provided layer and all following -// layers. -func totalLength(l Layer) int { - var totalLength int - for ; l != nil; l = l.next() { - totalLength += l.length() - } - return totalLength -} - -// payload returns a buffer.VectorisedView of l's payload. -func payload(l Layer) (buffer.VectorisedView, error) { - var payloadBytes buffer.VectorisedView - for current := l.next(); current != nil; current = current.next() { - payload, err := current.ToBytes() - if err != nil { - return buffer.VectorisedView{}, fmt.Errorf("can't get bytes for next header: %s", payload) - } - payloadBytes.AppendView(payload) - } - return payloadBytes, nil -} - -// layerChecksum calculates the checksum of the Layer header, including the -// peusdeochecksum of the layer before it and all the bytes after it. -func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) { - totalLength := uint16(totalLength(l)) - var xsum uint16 - switch p := l.Prev().(type) { - case *IPv4: - xsum = header.PseudoHeaderChecksum(protoNumber, *p.SrcAddr, *p.DstAddr, totalLength) - case *IPv6: - xsum = header.PseudoHeaderChecksum(protoNumber, *p.SrcAddr, *p.DstAddr, totalLength) - default: - // TODO(b/161246171): Support more protocols. - return 0, fmt.Errorf("checksum for protocol %d is not supported when previous layer is %T", protoNumber, p) - } - payloadBytes, err := payload(l) - if err != nil { - return 0, err - } - xsum = header.ChecksumVV(payloadBytes, xsum) - return xsum, nil -} - -// setTCPChecksum calculates the checksum of the TCP header and sets it in h. -func setTCPChecksum(h *header.TCP, tcp *TCP) error { - h.SetChecksum(0) - xsum, err := layerChecksum(tcp, header.TCPProtocolNumber) - if err != nil { - return err - } - h.SetChecksum(^h.CalculateChecksum(xsum)) - return nil -} - -// Uint32 is a helper routine that allocates a new -// uint32 value to store v and returns a pointer to it. -func Uint32(v uint32) *uint32 { - return &v -} - -// parseTCP parses the bytes assuming that they start with a tcp header and -// continues parsing further encapsulations. -func parseTCP(b []byte) (Layer, layerParser) { - h := header.TCP(b) - tcp := TCP{ - SrcPort: Uint16(h.SourcePort()), - DstPort: Uint16(h.DestinationPort()), - SeqNum: Uint32(h.SequenceNumber()), - AckNum: Uint32(h.AckNumber()), - DataOffset: Uint8(h.DataOffset()), - Flags: Uint8(h.Flags()), - WindowSize: Uint16(h.WindowSize()), - Checksum: Uint16(h.Checksum()), - UrgentPointer: Uint16(h.UrgentPointer()), - Options: b[header.TCPMinimumSize:h.DataOffset()], - } - return &tcp, parsePayload -} - -func (l *TCP) match(other Layer) bool { - return equalLayer(l, other) -} - -func (l *TCP) length() int { - if l.DataOffset == nil { - // TCP header including the options must end on a 32-bit - // boundary; the user could potentially give us a slice - // whose length is not a multiple of 4 bytes, so we have - // to do the alignment here. - optlen := (len(l.Options) + 3) & ^3 - return header.TCPMinimumSize + optlen - } - return int(*l.DataOffset) -} - -// merge implements Layer.merge. -func (l *TCP) merge(other Layer) error { - return mergeLayer(l, other) -} - -// UDP can construct and match a UDP encapsulation. -type UDP struct { - LayerBase - SrcPort *uint16 - DstPort *uint16 - Length *uint16 - Checksum *uint16 -} - -func (l *UDP) String() string { - return stringLayer(l) -} - -// ToBytes implements Layer.ToBytes. -func (l *UDP) ToBytes() ([]byte, error) { - b := make([]byte, header.UDPMinimumSize) - h := header.UDP(b) - if l.SrcPort != nil { - h.SetSourcePort(*l.SrcPort) - } - if l.DstPort != nil { - h.SetDestinationPort(*l.DstPort) - } - if l.Length != nil { - h.SetLength(*l.Length) - } else { - h.SetLength(uint16(totalLength(l))) - } - if l.Checksum != nil { - h.SetChecksum(*l.Checksum) - return h, nil - } - if err := setUDPChecksum(&h, l); err != nil { - return nil, err - } - return h, nil -} - -// setUDPChecksum calculates the checksum of the UDP header and sets it in h. -func setUDPChecksum(h *header.UDP, udp *UDP) error { - h.SetChecksum(0) - xsum, err := layerChecksum(udp, header.UDPProtocolNumber) - if err != nil { - return err - } - h.SetChecksum(^h.CalculateChecksum(xsum)) - return nil -} - -// parseUDP parses the bytes assuming that they start with a udp header and -// returns the parsed layer and the next parser to use. -func parseUDP(b []byte) (Layer, layerParser) { - h := header.UDP(b) - udp := UDP{ - SrcPort: Uint16(h.SourcePort()), - DstPort: Uint16(h.DestinationPort()), - Length: Uint16(h.Length()), - Checksum: Uint16(h.Checksum()), - } - return &udp, parsePayload -} - -func (l *UDP) match(other Layer) bool { - return equalLayer(l, other) -} - -func (l *UDP) length() int { - return header.UDPMinimumSize -} - -// merge implements Layer.merge. -func (l *UDP) merge(other Layer) error { - return mergeLayer(l, other) -} - -// Payload has bytes beyond OSI layer 4. -type Payload struct { - LayerBase - Bytes []byte -} - -func (l *Payload) String() string { - return stringLayer(l) -} - -// parsePayload parses the bytes assuming that they start with a payload and -// continue to the end. There can be no further encapsulations. -func parsePayload(b []byte) (Layer, layerParser) { - payload := Payload{ - Bytes: b, - } - return &payload, nil -} - -// ToBytes implements Layer.ToBytes. -func (l *Payload) ToBytes() ([]byte, error) { - return l.Bytes, nil -} - -// Length returns payload byte length. -func (l *Payload) Length() int { - return l.length() -} - -func (l *Payload) match(other Layer) bool { - return equalLayer(l, other) -} - -func (l *Payload) length() int { - return len(l.Bytes) -} - -// merge implements Layer.merge. -func (l *Payload) merge(other Layer) error { - return mergeLayer(l, other) -} - -// Layers is an array of Layer and supports similar functions to Layer. -type Layers []Layer - -// linkLayers sets the linked-list ponters in ls. -func (ls *Layers) linkLayers() { - for i, l := range *ls { - if i > 0 { - l.setPrev((*ls)[i-1]) - } else { - l.setPrev(nil) - } - if i+1 < len(*ls) { - l.setNext((*ls)[i+1]) - } else { - l.setNext(nil) - } - } -} - -// ToBytes converts the Layers into bytes. It creates a linked list of the Layer -// structs and then concatentates the output of ToBytes on each Layer. -func (ls *Layers) ToBytes() ([]byte, error) { - ls.linkLayers() - outBytes := []byte{} - for _, l := range *ls { - layerBytes, err := l.ToBytes() - if err != nil { - return nil, err - } - outBytes = append(outBytes, layerBytes...) - } - return outBytes, nil -} - -func (ls *Layers) match(other Layers) bool { - if len(*ls) > len(other) { - return false - } - for i, l := range *ls { - if !equalLayer(l, other[i]) { - return false - } - } - return true -} - -// layerDiff stores the diffs for each field along with the label for the Layer. -// If rows is nil, that means that there was no diff. -type layerDiff struct { - label string - rows []layerDiffRow -} - -// layerDiffRow stores the fields and corresponding values for two got and want -// layers. If the value was nil then the string stored is the empty string. -type layerDiffRow struct { - field, got, want string -} - -// diffLayer extracts all differing fields between two layers. -func diffLayer(got, want Layer) []layerDiffRow { - vGot := reflect.ValueOf(got).Elem() - vWant := reflect.ValueOf(want).Elem() - if vGot.Type() != vWant.Type() { - return nil - } - t := vGot.Type() - var result []layerDiffRow - for i := 0; i < t.NumField(); i++ { - t := t.Field(i) - if t.Anonymous { - // Ignore the LayerBase in the Layer struct. - continue - } - vGot := vGot.Field(i) - vWant := vWant.Field(i) - gotString := "" - if !vGot.IsNil() { - gotString = fmt.Sprint(reflect.Indirect(vGot)) - } - wantString := "" - if !vWant.IsNil() { - wantString = fmt.Sprint(reflect.Indirect(vWant)) - } - result = append(result, layerDiffRow{t.Name, gotString, wantString}) - } - return result -} - -// layerType returns a concise string describing the type of the Layer, like -// "TCP", or "IPv6". -func layerType(l Layer) string { - return reflect.TypeOf(l).Elem().Name() -} - -// diff compares Layers and returns a representation of the difference. Each -// Layer in the Layers is pairwise compared. If an element in either is nil, it -// is considered a match with the other Layer. If two Layers have differing -// types, they don't match regardless of the contents. If two Layers have the -// same type then the fields in the Layer are pairwise compared. Fields that are -// nil always match. Two non-nil fields only match if they point to equal -// values. diff returns an empty string if and only if *ls and other match. -func (ls *Layers) diff(other Layers) string { - var allDiffs []layerDiff - // Check the cases where one list is longer than the other, where one or both - // elements are nil, where the sides have different types, and where the sides - // have the same type. - for i := 0; i < len(*ls) || i < len(other); i++ { - if i >= len(*ls) { - // Matching ls against other where other is longer than ls. missing - // matches everything so we just include a label without any rows. Having - // no rows is a sign that there was no diff. - allDiffs = append(allDiffs, layerDiff{ - label: "missing matches " + layerType(other[i]), - }) - continue - } - - if i >= len(other) { - // Matching ls against other where ls is longer than other. missing - // matches everything so we just include a label without any rows. Having - // no rows is a sign that there was no diff. - allDiffs = append(allDiffs, layerDiff{ - label: layerType((*ls)[i]) + " matches missing", - }) - continue - } - - if (*ls)[i] == nil && other[i] == nil { - // Matching ls against other where both elements are nil. nil matches - // everything so we just include a label without any rows. Having no rows - // is a sign that there was no diff. - allDiffs = append(allDiffs, layerDiff{ - label: "nil matches nil", - }) - continue - } - - if (*ls)[i] == nil { - // Matching ls against other where the element in ls is nil. nil matches - // everything so we just include a label without any rows. Having no rows - // is a sign that there was no diff. - allDiffs = append(allDiffs, layerDiff{ - label: "nil matches " + layerType(other[i]), - }) - continue - } - - if other[i] == nil { - // Matching ls against other where the element in other is nil. nil - // matches everything so we just include a label without any rows. Having - // no rows is a sign that there was no diff. - allDiffs = append(allDiffs, layerDiff{ - label: layerType((*ls)[i]) + " matches nil", - }) - continue - } - - if reflect.TypeOf((*ls)[i]) == reflect.TypeOf(other[i]) { - // Matching ls against other where both elements have the same type. Match - // each field pairwise and only report a diff if there is a mismatch, - // which is only when both sides are non-nil and have differring values. - diff := diffLayer((*ls)[i], other[i]) - var layerDiffRows []layerDiffRow - for _, d := range diff { - if d.got == "" || d.want == "" || d.got == d.want { - continue - } - layerDiffRows = append(layerDiffRows, layerDiffRow{ - d.field, - d.got, - d.want, - }) - } - if len(layerDiffRows) > 0 { - allDiffs = append(allDiffs, layerDiff{ - label: layerType((*ls)[i]), - rows: layerDiffRows, - }) - } else { - allDiffs = append(allDiffs, layerDiff{ - label: layerType((*ls)[i]) + " matches " + layerType(other[i]), - // Having no rows is a sign that there was no diff. - }) - } - continue - } - // Neither side is nil and the types are different, so we'll display one - // side then the other. - allDiffs = append(allDiffs, layerDiff{ - label: layerType((*ls)[i]) + " doesn't match " + layerType(other[i]), - }) - diff := diffLayer((*ls)[i], (*ls)[i]) - layerDiffRows := []layerDiffRow{} - for _, d := range diff { - if len(d.got) == 0 { - continue - } - layerDiffRows = append(layerDiffRows, layerDiffRow{ - d.field, - d.got, - "", - }) - } - allDiffs = append(allDiffs, layerDiff{ - label: layerType((*ls)[i]), - rows: layerDiffRows, - }) - - layerDiffRows = []layerDiffRow{} - diff = diffLayer(other[i], other[i]) - for _, d := range diff { - if len(d.want) == 0 { - continue - } - layerDiffRows = append(layerDiffRows, layerDiffRow{ - d.field, - "", - d.want, - }) - } - allDiffs = append(allDiffs, layerDiff{ - label: layerType(other[i]), - rows: layerDiffRows, - }) - } - - output := "" - // These are for output formatting. - maxLabelLen, maxFieldLen, maxGotLen, maxWantLen := 0, 0, 0, 0 - foundOne := false - for _, l := range allDiffs { - if len(l.label) > maxLabelLen && len(l.rows) > 0 { - maxLabelLen = len(l.label) - } - if l.rows != nil { - foundOne = true - } - for _, r := range l.rows { - if len(r.field) > maxFieldLen { - maxFieldLen = len(r.field) - } - if l := len(fmt.Sprint(r.got)); l > maxGotLen { - maxGotLen = l - } - if l := len(fmt.Sprint(r.want)); l > maxWantLen { - maxWantLen = l - } - } - } - if !foundOne { - return "" - } - for _, l := range allDiffs { - if len(l.rows) == 0 { - output += "(" + l.label + ")\n" - continue - } - for i, r := range l.rows { - var label string - if i == 0 { - label = l.label + ":" - } - output += fmt.Sprintf( - "%*s %*s %*v %*v\n", - maxLabelLen+1, label, - maxFieldLen+1, r.field+":", - maxGotLen, r.got, - maxWantLen, r.want, - ) - } - } - return output -} - -// merge merges the other Layers into ls. If the other Layers is longer, those -// additional Layer structs are added to ls. The errors from merging are -// collected and returned. -func (ls *Layers) merge(other Layers) error { - var errs error - for i, o := range other { - if i < len(*ls) { - errs = multierr.Combine(errs, (*ls)[i].merge(o)) - } else { - *ls = append(*ls, o) - } - } - return errs -} diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go deleted file mode 100644 index eca0780b5..000000000 --- a/test/packetimpact/testbench/layers_test.go +++ /dev/null @@ -1,728 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testbench - -import ( - "bytes" - "net" - "testing" - - "github.com/mohae/deepcopy" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestLayerMatch(t *testing.T) { - var nilPayload *Payload - noPayload := &Payload{} - emptyPayload := &Payload{Bytes: []byte{}} - fullPayload := &Payload{Bytes: []byte{1, 2, 3}} - emptyTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: emptyPayload}} - fullTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: fullPayload}} - for _, tt := range []struct { - a, b Layer - want bool - }{ - {nilPayload, nilPayload, true}, - {nilPayload, noPayload, true}, - {nilPayload, emptyPayload, true}, - {nilPayload, fullPayload, true}, - {noPayload, noPayload, true}, - {noPayload, emptyPayload, true}, - {noPayload, fullPayload, true}, - {emptyPayload, emptyPayload, true}, - {emptyPayload, fullPayload, false}, - {fullPayload, fullPayload, true}, - {emptyTCP, fullTCP, true}, - } { - if got := tt.a.match(tt.b); got != tt.want { - t.Errorf("%s.match(%s) = %t, want %t", tt.a, tt.b, got, tt.want) - } - if got := tt.b.match(tt.a); got != tt.want { - t.Errorf("%s.match(%s) = %t, want %t", tt.b, tt.a, got, tt.want) - } - } -} - -func TestLayerMergeMismatch(t *testing.T) { - tcp := &TCP{} - otherTCP := &TCP{} - ipv4 := &IPv4{} - ether := &Ether{} - for _, tt := range []struct { - a, b Layer - success bool - }{ - {tcp, tcp, true}, - {tcp, otherTCP, true}, - {tcp, ipv4, false}, - {tcp, ether, false}, - {tcp, nil, true}, - - {otherTCP, otherTCP, true}, - {otherTCP, ipv4, false}, - {otherTCP, ether, false}, - {otherTCP, nil, true}, - - {ipv4, ipv4, true}, - {ipv4, ether, false}, - {ipv4, nil, true}, - - {ether, ether, true}, - {ether, nil, true}, - } { - if err := tt.a.merge(tt.b); (err == nil) != tt.success { - t.Errorf("%s.merge(%s) got %s, wanted the opposite", tt.a, tt.b, err) - } - if tt.b != nil { - if err := tt.b.merge(tt.a); (err == nil) != tt.success { - t.Errorf("%s.merge(%s) got %s, wanted the opposite", tt.b, tt.a, err) - } - } - } -} - -func TestLayerMerge(t *testing.T) { - zero := Uint32(0) - one := Uint32(1) - two := Uint32(2) - empty := []byte{} - foo := []byte("foo") - bar := []byte("bar") - for _, tt := range []struct { - a, b Layer - want Layer - }{ - {&TCP{AckNum: nil}, &TCP{AckNum: nil}, &TCP{AckNum: nil}}, - {&TCP{AckNum: nil}, &TCP{AckNum: zero}, &TCP{AckNum: zero}}, - {&TCP{AckNum: nil}, &TCP{AckNum: one}, &TCP{AckNum: one}}, - {&TCP{AckNum: nil}, &TCP{AckNum: two}, &TCP{AckNum: two}}, - {&TCP{AckNum: nil}, nil, &TCP{AckNum: nil}}, - - {&TCP{AckNum: zero}, &TCP{AckNum: nil}, &TCP{AckNum: zero}}, - {&TCP{AckNum: zero}, &TCP{AckNum: zero}, &TCP{AckNum: zero}}, - {&TCP{AckNum: zero}, &TCP{AckNum: one}, &TCP{AckNum: one}}, - {&TCP{AckNum: zero}, &TCP{AckNum: two}, &TCP{AckNum: two}}, - {&TCP{AckNum: zero}, nil, &TCP{AckNum: zero}}, - - {&TCP{AckNum: one}, &TCP{AckNum: nil}, &TCP{AckNum: one}}, - {&TCP{AckNum: one}, &TCP{AckNum: zero}, &TCP{AckNum: zero}}, - {&TCP{AckNum: one}, &TCP{AckNum: one}, &TCP{AckNum: one}}, - {&TCP{AckNum: one}, &TCP{AckNum: two}, &TCP{AckNum: two}}, - {&TCP{AckNum: one}, nil, &TCP{AckNum: one}}, - - {&TCP{AckNum: two}, &TCP{AckNum: nil}, &TCP{AckNum: two}}, - {&TCP{AckNum: two}, &TCP{AckNum: zero}, &TCP{AckNum: zero}}, - {&TCP{AckNum: two}, &TCP{AckNum: one}, &TCP{AckNum: one}}, - {&TCP{AckNum: two}, &TCP{AckNum: two}, &TCP{AckNum: two}}, - {&TCP{AckNum: two}, nil, &TCP{AckNum: two}}, - - {&Payload{Bytes: nil}, &Payload{Bytes: nil}, &Payload{Bytes: nil}}, - {&Payload{Bytes: nil}, &Payload{Bytes: empty}, &Payload{Bytes: empty}}, - {&Payload{Bytes: nil}, &Payload{Bytes: foo}, &Payload{Bytes: foo}}, - {&Payload{Bytes: nil}, &Payload{Bytes: bar}, &Payload{Bytes: bar}}, - {&Payload{Bytes: nil}, nil, &Payload{Bytes: nil}}, - - {&Payload{Bytes: empty}, &Payload{Bytes: nil}, &Payload{Bytes: empty}}, - {&Payload{Bytes: empty}, &Payload{Bytes: empty}, &Payload{Bytes: empty}}, - {&Payload{Bytes: empty}, &Payload{Bytes: foo}, &Payload{Bytes: foo}}, - {&Payload{Bytes: empty}, &Payload{Bytes: bar}, &Payload{Bytes: bar}}, - {&Payload{Bytes: empty}, nil, &Payload{Bytes: empty}}, - - {&Payload{Bytes: foo}, &Payload{Bytes: nil}, &Payload{Bytes: foo}}, - {&Payload{Bytes: foo}, &Payload{Bytes: empty}, &Payload{Bytes: empty}}, - {&Payload{Bytes: foo}, &Payload{Bytes: foo}, &Payload{Bytes: foo}}, - {&Payload{Bytes: foo}, &Payload{Bytes: bar}, &Payload{Bytes: bar}}, - {&Payload{Bytes: foo}, nil, &Payload{Bytes: foo}}, - - {&Payload{Bytes: bar}, &Payload{Bytes: nil}, &Payload{Bytes: bar}}, - {&Payload{Bytes: bar}, &Payload{Bytes: empty}, &Payload{Bytes: empty}}, - {&Payload{Bytes: bar}, &Payload{Bytes: foo}, &Payload{Bytes: foo}}, - {&Payload{Bytes: bar}, &Payload{Bytes: bar}, &Payload{Bytes: bar}}, - {&Payload{Bytes: bar}, nil, &Payload{Bytes: bar}}, - } { - a := deepcopy.Copy(tt.a).(Layer) - if err := a.merge(tt.b); err != nil { - t.Errorf("%s.merge(%s) = %s, wanted nil", tt.a, tt.b, err) - continue - } - if a.String() != tt.want.String() { - t.Errorf("%s.merge(%s) merge result got %s, want %s", tt.a, tt.b, a, tt.want) - } - } -} - -func TestLayerStringFormat(t *testing.T) { - for _, tt := range []struct { - name string - l Layer - want string - }{ - { - name: "TCP", - l: &TCP{ - SrcPort: Uint16(34785), - DstPort: Uint16(47767), - SeqNum: Uint32(3452155723), - AckNum: Uint32(2596996163), - DataOffset: Uint8(5), - Flags: Uint8(20), - WindowSize: Uint16(64240), - Checksum: Uint16(0x2e2b), - }, - want: "&testbench.TCP{" + - "SrcPort:34785 " + - "DstPort:47767 " + - "SeqNum:3452155723 " + - "AckNum:2596996163 " + - "DataOffset:5 " + - "Flags:20 " + - "WindowSize:64240 " + - "Checksum:11819" + - "}", - }, - { - name: "UDP", - l: &UDP{ - SrcPort: Uint16(34785), - DstPort: Uint16(47767), - Length: Uint16(12), - }, - want: "&testbench.UDP{" + - "SrcPort:34785 " + - "DstPort:47767 " + - "Length:12" + - "}", - }, - { - name: "IPv4", - l: &IPv4{ - IHL: Uint8(5), - TOS: Uint8(0), - TotalLength: Uint16(44), - ID: Uint16(0), - Flags: Uint8(2), - FragmentOffset: Uint16(0), - TTL: Uint8(64), - Protocol: Uint8(6), - Checksum: Uint16(0x2e2b), - SrcAddr: Address(tcpip.Address([]byte{197, 34, 63, 10})), - DstAddr: Address(tcpip.Address([]byte{197, 34, 63, 20})), - }, - want: "&testbench.IPv4{" + - "IHL:5 " + - "TOS:0 " + - "TotalLength:44 " + - "ID:0 " + - "Flags:2 " + - "FragmentOffset:0 " + - "TTL:64 " + - "Protocol:6 " + - "Checksum:11819 " + - "SrcAddr:197.34.63.10 " + - "DstAddr:197.34.63.20" + - "}", - }, - { - name: "Ether", - l: &Ether{ - SrcAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x0a})), - DstAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x14})), - Type: NetworkProtocolNumber(4), - }, - want: "&testbench.Ether{" + - "SrcAddr:02:42:c5:22:3f:0a " + - "DstAddr:02:42:c5:22:3f:14 " + - "Type:4" + - "}", - }, - { - name: "Payload", - l: &Payload{ - Bytes: []byte("Hooray for packetimpact."), - }, - want: "&testbench.Payload{Bytes:\n" + - "00000000 48 6f 6f 72 61 79 20 66 6f 72 20 70 61 63 6b 65 |Hooray for packe|\n" + - "00000010 74 69 6d 70 61 63 74 2e |timpact.|\n" + - "}", - }, - } { - t.Run(tt.name, func(t *testing.T) { - if got := tt.l.String(); got != tt.want { - t.Errorf("%s.String() = %s, want: %s", tt.name, got, tt.want) - } - }) - } -} - -func TestConnectionMatch(t *testing.T) { - conn := Connection{ - layerStates: []layerState{ðerState{}}, - } - protoNum0 := tcpip.NetworkProtocolNumber(0) - protoNum1 := tcpip.NetworkProtocolNumber(1) - for _, tt := range []struct { - description string - override, received Layers - wantMatch bool - }{ - { - description: "shorter override", - override: []Layer{&Ether{}}, - received: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}}, - wantMatch: true, - }, - { - description: "longer override", - override: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}}, - received: []Layer{&Ether{}}, - wantMatch: false, - }, - { - description: "ether layer mismatch", - override: []Layer{&Ether{Type: &protoNum0}}, - received: []Layer{&Ether{Type: &protoNum1}}, - wantMatch: false, - }, - { - description: "both nil", - override: nil, - received: nil, - wantMatch: false, - }, - { - description: "nil override", - override: nil, - received: []Layer{&Ether{}}, - wantMatch: true, - }, - } { - t.Run(tt.description, func(t *testing.T) { - if gotMatch := conn.match(tt.override, tt.received); gotMatch != tt.wantMatch { - t.Fatalf("conn.match(%s, %s) = %t, want %t", tt.override, tt.received, gotMatch, tt.wantMatch) - } - }) - } -} - -func TestLayersDiff(t *testing.T) { - for _, tt := range []struct { - x, y Layers - want string - }{ - { - Layers{&Ether{Type: NetworkProtocolNumber(12)}, &TCP{DataOffset: Uint8(5), SeqNum: Uint32(5)}}, - Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, - "Ether: Type: 12 13\n" + - " TCP: SeqNum: 5 6\n" + - " DataOffset: 5 7\n", - }, - { - Layers{&Ether{Type: NetworkProtocolNumber(12)}, &UDP{SrcPort: Uint16(123)}}, - Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, - "Ether: Type: 12 13\n" + - "(UDP doesn't match TCP)\n" + - " UDP: SrcPort: 123 \n" + - " TCP: SeqNum: 6\n" + - " DataOffset: 7\n", - }, - { - Layers{&UDP{SrcPort: Uint16(123)}}, - Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, - "(UDP doesn't match Ether)\n" + - " UDP: SrcPort: 123 \n" + - "Ether: Type: 13\n" + - "(missing matches TCP)\n", - }, - { - Layers{nil, &UDP{SrcPort: Uint16(123)}}, - Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, - "(nil matches Ether)\n" + - "(UDP doesn't match TCP)\n" + - "UDP: SrcPort: 123 \n" + - "TCP: SeqNum: 6\n" + - " DataOffset: 7\n", - }, - { - Layers{&Ether{Type: NetworkProtocolNumber(13)}, &IPv4{IHL: Uint8(4)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, - Layers{&Ether{Type: NetworkProtocolNumber(13)}, &IPv4{IHL: Uint8(6)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, - "(Ether matches Ether)\n" + - "IPv4: IHL: 4 6\n" + - "(TCP matches TCP)\n", - }, - { - Layers{&Payload{Bytes: []byte("foo")}}, - Layers{&Payload{Bytes: []byte("bar")}}, - "Payload: Bytes: [102 111 111] [98 97 114]\n", - }, - { - Layers{&Payload{Bytes: []byte("")}}, - Layers{&Payload{}}, - "", - }, - { - Layers{&Payload{Bytes: []byte("")}}, - Layers{&Payload{Bytes: []byte("")}}, - "", - }, - { - Layers{&UDP{}}, - Layers{&TCP{}}, - "(UDP doesn't match TCP)\n" + - "(UDP)\n" + - "(TCP)\n", - }, - } { - if got := tt.x.diff(tt.y); got != tt.want { - t.Errorf("%s.diff(%s) = %q, want %q", tt.x, tt.y, got, tt.want) - } - if tt.x.match(tt.y) != (tt.x.diff(tt.y) == "") { - t.Errorf("match and diff of %s and %s disagree", tt.x, tt.y) - } - if tt.y.match(tt.x) != (tt.y.diff(tt.x) == "") { - t.Errorf("match and diff of %s and %s disagree", tt.y, tt.x) - } - } -} - -func TestTCPOptions(t *testing.T) { - for _, tt := range []struct { - description string - wantBytes []byte - wantLayers Layers - }{ - { - description: "without payload", - wantBytes: []byte{ - // IPv4 Header - 0x45, 0x00, 0x00, 0x2c, 0x00, 0x01, 0x00, 0x00, 0x40, 0x06, - 0xf9, 0x77, 0xc0, 0xa8, 0x00, 0x02, 0xc0, 0xa8, 0x00, 0x01, - // TCP Header - 0x30, 0x39, 0xd4, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x60, 0x02, 0x20, 0x00, 0xf5, 0x1c, 0x00, 0x00, - // WindowScale Option - 0x03, 0x03, 0x02, - // NOP Option - 0x00, - }, - wantLayers: []Layer{ - &IPv4{ - IHL: Uint8(20), - TOS: Uint8(0), - TotalLength: Uint16(44), - ID: Uint16(1), - Flags: Uint8(0), - FragmentOffset: Uint16(0), - TTL: Uint8(64), - Protocol: Uint8(uint8(header.TCPProtocolNumber)), - Checksum: Uint16(0xf977), - SrcAddr: Address(tcpip.Address(net.ParseIP("192.168.0.2").To4())), - DstAddr: Address(tcpip.Address(net.ParseIP("192.168.0.1").To4())), - }, - &TCP{ - SrcPort: Uint16(12345), - DstPort: Uint16(54321), - SeqNum: Uint32(0), - AckNum: Uint32(0), - Flags: Uint8(header.TCPFlagSyn), - WindowSize: Uint16(8192), - Checksum: Uint16(0xf51c), - UrgentPointer: Uint16(0), - Options: []byte{3, 3, 2, 0}, - }, - &Payload{Bytes: nil}, - }, - }, - { - description: "with payload", - wantBytes: []byte{ - // IPv4 header - 0x45, 0x00, 0x00, 0x37, 0x00, 0x01, 0x00, 0x00, 0x40, 0x06, - 0xf9, 0x6c, 0xc0, 0xa8, 0x00, 0x02, 0xc0, 0xa8, 0x00, 0x01, - // TCP header - 0x30, 0x39, 0xd4, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x60, 0x02, 0x20, 0x00, 0xe5, 0x21, 0x00, 0x00, - // WindowScale Option - 0x03, 0x03, 0x02, - // NOP Option - 0x00, - // Payload: "Sample Data" - 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61, - }, - wantLayers: []Layer{ - &IPv4{ - IHL: Uint8(20), - TOS: Uint8(0), - TotalLength: Uint16(55), - ID: Uint16(1), - Flags: Uint8(0), - FragmentOffset: Uint16(0), - TTL: Uint8(64), - Protocol: Uint8(uint8(header.TCPProtocolNumber)), - Checksum: Uint16(0xf96c), - SrcAddr: Address(tcpip.Address(net.ParseIP("192.168.0.2").To4())), - DstAddr: Address(tcpip.Address(net.ParseIP("192.168.0.1").To4())), - }, - &TCP{ - SrcPort: Uint16(12345), - DstPort: Uint16(54321), - SeqNum: Uint32(0), - AckNum: Uint32(0), - Flags: Uint8(header.TCPFlagSyn), - WindowSize: Uint16(8192), - Checksum: Uint16(0xe521), - UrgentPointer: Uint16(0), - Options: []byte{3, 3, 2, 0}, - }, - &Payload{Bytes: []byte("Sample Data")}, - }, - }, - } { - t.Run(tt.description, func(t *testing.T) { - layers := parse(parseIPv4, tt.wantBytes) - if !layers.match(tt.wantLayers) { - t.Fatalf("match failed with diff: %s", layers.diff(tt.wantLayers)) - } - gotBytes, err := layers.ToBytes() - if err != nil { - t.Fatalf("ToBytes() failed on %s: %s", &layers, err) - } - if !bytes.Equal(tt.wantBytes, gotBytes) { - t.Fatalf("mismatching bytes, gotBytes: %x, wantBytes: %x", gotBytes, tt.wantBytes) - } - }) - } -} - -func TestIPv6ExtHdrOptions(t *testing.T) { - for _, tt := range []struct { - description string - wantBytes []byte - wantLayers Layers - }{ - { - description: "IPv6/HopByHop", - wantBytes: []byte{ - // IPv6 Header - 0x60, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x40, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, - // HopByHop Options - 0x3b, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, - }, - wantLayers: []Layer{ - &IPv6{ - SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), - DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), - }, - &IPv6HopByHopOptionsExtHdr{ - NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), - Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, - }, - &Payload{ - Bytes: nil, - }, - }, - }, - { - description: "IPv6/HopByHop/Payload", - wantBytes: []byte{ - // IPv6 Header - 0x60, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x40, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, - // HopByHop Options - 0x3b, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, - // Sample Data - 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61, - }, - wantLayers: []Layer{ - &IPv6{ - SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), - DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), - }, - &IPv6HopByHopOptionsExtHdr{ - NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), - Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, - }, - &Payload{ - Bytes: []byte("Sample Data"), - }, - }, - }, - { - description: "IPv6/HopByHop/Destination/ICMPv6", - wantBytes: []byte{ - // IPv6 Header - 0x60, 0x00, 0x00, 0x00, 0x00, 0x18, 0x00, 0x40, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, - // HopByHop Options - 0x3c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, - // Destination Options - 0x3a, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, - // ICMPv6 Param Problem - 0x04, 0x00, 0x5f, 0x98, 0x00, 0x00, 0x00, 0x06, - }, - wantLayers: []Layer{ - &IPv6{ - SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), - DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), - }, - &IPv6HopByHopOptionsExtHdr{ - NextHeader: IPv6ExtHdrIdent(header.IPv6DestinationOptionsExtHdrIdentifier), - Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, - }, - &IPv6DestinationOptionsExtHdr{ - NextHeader: IPv6ExtHdrIdent(header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber)), - Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, - }, - &ICMPv6{ - Type: ICMPv6Type(header.ICMPv6ParamProblem), - Code: ICMPv6Code(header.ICMPv6ErroneousHeader), - Checksum: Uint16(0x5f98), - Payload: []byte{0x00, 0x00, 0x00, 0x06}, - }, - }, - }, - { - description: "IPv6/HopByHop/Fragment", - wantBytes: []byte{ - // IPv6 Header - 0x60, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x40, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, - // HopByHop Options - 0x2c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, - // Fragment ExtHdr - 0x3b, 0x00, 0x03, 0x20, 0x00, 0x00, 0x00, 0x2a, - }, - wantLayers: []Layer{ - &IPv6{ - SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), - DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), - }, - &IPv6HopByHopOptionsExtHdr{ - NextHeader: IPv6ExtHdrIdent(header.IPv6FragmentExtHdrIdentifier), - Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, - }, - &IPv6FragmentExtHdr{ - NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), - FragmentOffset: Uint16(100), - MoreFragments: Bool(false), - Identification: Uint32(42), - }, - &Payload{ - Bytes: nil, - }, - }, - }, - { - description: "IPv6/DestOpt/Fragment/Payload", - wantBytes: []byte{ - // IPv6 Header - 0x60, 0x00, 0x00, 0x00, 0x00, 0x1b, 0x3c, 0x40, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, - // Destination Options - 0x2c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, - // Fragment ExtHdr - 0x3b, 0x00, 0x03, 0x21, 0x00, 0x00, 0x00, 0x2a, - // Sample Data - 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61, - }, - wantLayers: []Layer{ - &IPv6{ - SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), - DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), - }, - &IPv6DestinationOptionsExtHdr{ - NextHeader: IPv6ExtHdrIdent(header.IPv6FragmentExtHdrIdentifier), - Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, - }, - &IPv6FragmentExtHdr{ - NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), - FragmentOffset: Uint16(100), - MoreFragments: Bool(true), - Identification: Uint32(42), - }, - &Payload{ - Bytes: []byte("Sample Data"), - }, - }, - }, - { - description: "IPv6/Fragment/Payload", - wantBytes: []byte{ - // IPv6 Header - 0x60, 0x00, 0x00, 0x00, 0x00, 0x13, 0x2c, 0x40, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, - // Fragment ExtHdr - 0x3b, 0x00, 0x03, 0x21, 0x00, 0x00, 0x00, 0x2a, - // Sample Data - 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61, - }, - wantLayers: []Layer{ - &IPv6{ - SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), - DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), - }, - &IPv6FragmentExtHdr{ - NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), - FragmentOffset: Uint16(100), - MoreFragments: Bool(true), - Identification: Uint32(42), - }, - &Payload{ - Bytes: []byte("Sample Data"), - }, - }, - }, - } { - t.Run(tt.description, func(t *testing.T) { - layers := parse(parseIPv6, tt.wantBytes) - if !layers.match(tt.wantLayers) { - t.Fatalf("match failed with diff: %s", layers.diff(tt.wantLayers)) - } - // Make sure we can generate correct next header values and checksums - for _, layer := range layers { - switch layer := layer.(type) { - case *IPv6HopByHopOptionsExtHdr: - layer.NextHeader = nil - case *IPv6DestinationOptionsExtHdr: - layer.NextHeader = nil - case *IPv6FragmentExtHdr: - layer.NextHeader = nil - case *ICMPv6: - layer.Checksum = nil - } - } - gotBytes, err := layers.ToBytes() - if err != nil { - t.Fatalf("ToBytes() failed on %s: %s", &layers, err) - } - if !bytes.Equal(tt.wantBytes, gotBytes) { - t.Fatalf("mismatching bytes, gotBytes: %x, wantBytes: %x", gotBytes, tt.wantBytes) - } - }) - } -} diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go deleted file mode 100644 index 1ac96626a..000000000 --- a/test/packetimpact/testbench/rawsockets.go +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testbench - -import ( - "encoding/binary" - "fmt" - "math" - "net" - "testing" - "time" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/usermem" -) - -// Sniffer can sniff raw packets on the wire. -type Sniffer struct { - fd int -} - -func htons(x uint16) uint16 { - buf := [2]byte{} - binary.BigEndian.PutUint16(buf[:], x) - return usermem.ByteOrder.Uint16(buf[:]) -} - -// NewSniffer creates a Sniffer connected to *device. -func (n *DUTTestNet) NewSniffer(t *testing.T) (Sniffer, error) { - t.Helper() - - ifInfo, err := net.InterfaceByName(n.LocalDevName) - if err != nil { - return Sniffer{}, err - } - - var haddr [8]byte - copy(haddr[:], ifInfo.HardwareAddr) - sa := unix.SockaddrLinklayer{ - Protocol: htons(unix.ETH_P_ALL), - Ifindex: ifInfo.Index, - } - snifferFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL))) - if err != nil { - return Sniffer{}, err - } - if err := unix.Bind(snifferFd, &sa); err != nil { - return Sniffer{}, err - } - if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, 1); err != nil { - t.Fatalf("can't set sockopt SO_RCVBUFFORCE to 1: %s", err) - } - if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1e7); err != nil { - t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err) - } - return Sniffer{ - fd: snifferFd, - }, nil -} - -// maxReadSize should be large enough for the maximum frame size in bytes. If a -// packet too large for the buffer arrives, the test will get a fatal error. -const maxReadSize int = 65536 - -// Recv tries to read one frame until the timeout is up. If the timeout given -// is 0, then no read attempt will be made. -func (s *Sniffer) Recv(t *testing.T, timeout time.Duration) []byte { - t.Helper() - - deadline := time.Now().Add(timeout) - for { - timeout = deadline.Sub(time.Now()) - if timeout <= 0 { - return nil - } - whole, frac := math.Modf(timeout.Seconds()) - tv := unix.Timeval{ - Sec: int64(whole), - Usec: int64(frac * float64(time.Second/time.Microsecond)), - } - // The following should never happen, but having this guard here is better - // than blocking indefinitely in the future. - if tv.Sec == 0 && tv.Usec == 0 { - t.Fatal("setting SO_RCVTIMEO to 0 means blocking indefinitely") - } - if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil { - t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err) - } - - buf := make([]byte, maxReadSize) - nread, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC) - if err == unix.EINTR || err == unix.EAGAIN { - // There was a timeout. - continue - } - if err != nil { - t.Fatalf("can't read: %s", err) - } - if nread > maxReadSize { - t.Fatalf("received a truncated frame of %d bytes, want at most %d bytes", nread, maxReadSize) - } - return buf[:nread] - } -} - -// Drain drains the Sniffer's socket receive buffer by receiving until there's -// nothing else to receive. -func (s *Sniffer) Drain(t *testing.T) { - t.Helper() - - flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0) - if err != nil { - t.Fatalf("failed to get sniffer socket fd flags: %s", err) - } - nonBlockingFlags := flags | unix.O_NONBLOCK - if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, nonBlockingFlags); err != nil { - t.Fatalf("failed to make sniffer socket non-blocking with flags %b: %s", nonBlockingFlags, err) - } - for { - buf := make([]byte, maxReadSize) - _, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC) - if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK { - break - } - } - if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil { - t.Fatalf("failed to restore sniffer socket fd flags to %b: %s", flags, err) - } -} - -// close the socket that Sniffer is using. -func (s *Sniffer) close() error { - if err := unix.Close(s.fd); err != nil { - return fmt.Errorf("can't close sniffer socket: %w", err) - } - s.fd = -1 - return nil -} - -// Injector can inject raw frames. -type Injector struct { - fd int -} - -// NewInjector creates a new injector on *device. -func (n *DUTTestNet) NewInjector(t *testing.T) (Injector, error) { - t.Helper() - - ifInfo, err := net.InterfaceByName(n.LocalDevName) - if err != nil { - return Injector{}, err - } - - var haddr [8]byte - copy(haddr[:], ifInfo.HardwareAddr) - sa := unix.SockaddrLinklayer{ - Protocol: htons(unix.ETH_P_IP), - Ifindex: ifInfo.Index, - Halen: uint8(len(ifInfo.HardwareAddr)), - Addr: haddr, - } - - injectFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL))) - if err != nil { - return Injector{}, err - } - if err := unix.Bind(injectFd, &sa); err != nil { - return Injector{}, err - } - return Injector{ - fd: injectFd, - }, nil -} - -// Send a raw frame. -func (i *Injector) Send(t *testing.T, b []byte) { - t.Helper() - - n, err := unix.Write(i.fd, b) - if err != nil { - t.Fatalf("can't write bytes of len %d: %s", len(b), err) - } - if n != len(b) { - t.Fatalf("got %d bytes written, want %d", n, len(b)) - } -} - -// close the underlying socket. -func (i *Injector) close() error { - if err := unix.Close(i.fd); err != nil { - return fmt.Errorf("can't close sniffer socket: %w", err) - } - i.fd = -1 - return nil -} diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go deleted file mode 100644 index a73c07e64..000000000 --- a/test/packetimpact/testbench/testbench.go +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testbench has utilities to send and receive packets, and also command -// the DUT to run POSIX functions. It is the packetimpact test API. -package testbench - -import ( - "encoding/json" - "flag" - "fmt" - "math/rand" - "net" - "testing" - "time" -) - -var ( - // Native indicates that the test is being run natively. - Native = false - // RPCKeepalive is the gRPC keepalive. - RPCKeepalive = 10 * time.Second - // RPCTimeout is the gRPC timeout. - RPCTimeout = 100 * time.Millisecond - - // dutInfosJSON is the json string that describes information about all the - // duts available to use. - dutInfosJSON string - // dutInfo is the pool among which the testbench can choose a DUT to work - // with. - dutInfo chan *DUTInfo -) - -// DUTInfo has both network and uname information about the DUT. -type DUTInfo struct { - Uname *DUTUname - Net *DUTTestNet -} - -// DUTUname contains information about the DUT from uname. -type DUTUname struct { - Machine string - KernelName string - KernelRelease string - KernelVersion string - OperatingSystem string -} - -// DUTTestNet describes the test network setup on dut and how the testbench -// should connect with an existing DUT. -type DUTTestNet struct { - // LocalMAC is the local MAC address on the test network. - LocalMAC net.HardwareAddr - // RemoteMAC is the DUT's MAC address on the test network. - RemoteMAC net.HardwareAddr - // LocalIPv4 is the local IPv4 address on the test network. - LocalIPv4 net.IP - // RemoteIPv4 is the DUT's IPv4 address on the test network. - RemoteIPv4 net.IP - // IPv4PrefixLength is the network prefix length of the IPv4 test network. - IPv4PrefixLength int - // LocalIPv6 is the local IPv6 address on the test network. - LocalIPv6 net.IP - // RemoteIPv6 is the DUT's IPv6 address on the test network. - RemoteIPv6 net.IP - // LocalDevID is the ID of the local interface on the test network. - LocalDevID uint32 - // RemoteDevID is the ID of the remote interface on the test network. - RemoteDevID uint32 - // LocalDevName is the device that testbench uses to inject traffic. - LocalDevName string - // RemoteDevName is the device name on the DUT, individual tests can - // use the name to construct tests. - RemoteDevName string - - // The following two fields on actually on the control network instead - // of the test network, including them for convenience. - - // POSIXServerIP is the POSIX server's IP address on the control network. - POSIXServerIP net.IP - // POSIXServerPort is the UDP port the POSIX server is bound to on the - // control network. - POSIXServerPort uint16 -} - -// registerFlags defines flags and associates them with the package-level -// exported variables above. It should be called by tests in their init -// functions. -func registerFlags(fs *flag.FlagSet) { - fs.BoolVar(&Native, "native", Native, "whether the test is running natively") - fs.DurationVar(&RPCTimeout, "rpc_timeout", RPCTimeout, "gRPC timeout") - fs.DurationVar(&RPCKeepalive, "rpc_keepalive", RPCKeepalive, "gRPC keepalive") - fs.StringVar(&dutInfosJSON, "dut_infos_json", dutInfosJSON, "json that describes the DUTs") -} - -// Initialize initializes the testbench, it parse the flags and sets up the -// pool of test networks for testbench's later use. -func Initialize(fs *flag.FlagSet) { - registerFlags(fs) - flag.Parse() - if err := loadDUTInfos(); err != nil { - panic(err) - } -} - -// loadDUTInfos loads available DUT test infos from the json file, it -// must be called after flag.Parse(). -func loadDUTInfos() error { - var dutInfos []DUTInfo - if err := json.Unmarshal([]byte(dutInfosJSON), &dutInfos); err != nil { - return fmt.Errorf("failed to unmarshal JSON: %w", err) - } - if got, want := len(dutInfos), 1; got < want { - return fmt.Errorf("got %d DUTs, the test requires at least %d DUTs", got, want) - } - // Using a buffered channel as semaphore - dutInfo = make(chan *DUTInfo, len(dutInfos)) - for i := range dutInfos { - dutInfos[i].Net.LocalIPv4 = dutInfos[i].Net.LocalIPv4.To4() - dutInfos[i].Net.RemoteIPv4 = dutInfos[i].Net.RemoteIPv4.To4() - dutInfo <- &dutInfos[i] - } - return nil -} - -// GenerateRandomPayload generates a random byte slice of the specified length, -// causing a fatal test failure if it is unable to do so. -func GenerateRandomPayload(t *testing.T, n int) []byte { - t.Helper() - buf := make([]byte, n) - if _, err := rand.Read(buf); err != nil { - t.Fatalf("rand.Read(buf) failed: %s", err) - } - return buf -} - -// getDUTInfo returns information about an available DUT from the pool. If no -// DUT is readily available, getDUTInfo blocks until one becomes available. -func getDUTInfo() *DUTInfo { - return <-dutInfo -} - -// release returns the DUTInfo back to the pool. -func (info *DUTInfo) release() { - dutInfo <- info -} |