diff options
Diffstat (limited to 'test/packetimpact/testbench')
-rw-r--r-- | test/packetimpact/testbench/BUILD | 1 | ||||
-rw-r--r-- | test/packetimpact/testbench/connections.go | 63 | ||||
-rw-r--r-- | test/packetimpact/testbench/layers.go | 99 | ||||
-rw-r--r-- | test/packetimpact/testbench/layers_test.go | 95 | ||||
-rw-r--r-- | test/packetimpact/testbench/rawsockets.go | 23 |
5 files changed, 212 insertions, 69 deletions
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD index 199823419..838a10ffe 100644 --- a/test/packetimpact/testbench/BUILD +++ b/test/packetimpact/testbench/BUILD @@ -36,4 +36,5 @@ go_test( size = "small", srcs = ["layers_test.go"], library = ":testbench", + deps = ["//pkg/tcpip"], ) diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index 579da59c3..2b8e2f005 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -21,6 +21,7 @@ import ( "fmt" "math/rand" "net" + "strings" "testing" "time" @@ -210,11 +211,7 @@ func (conn *TCPIPv4) RecvFrame(timeout time.Duration) Layers { if b == nil { break } - layers, err := ParseEther(b) - if err != nil { - conn.t.Logf("can't parse frame: %s", err) - continue // Ignore packets that can't be parsed. - } + layers := Parse(ParseEther, b) if !conn.incoming.match(layers) { continue // Ignore packets that don't match the expected incoming. } @@ -231,21 +228,31 @@ func (conn *TCPIPv4) RecvFrame(timeout time.Duration) Layers { return nil } +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *TCPIPv4) Drain() { + conn.sniffer.Drain() +} + // Expect a packet that matches the provided tcp within the timeout specified. // If it doesn't arrive in time, it returns nil. -func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) *TCP { +func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) { // We cannot implement this directly using ExpectFrame as we cannot specify // the Payload part. deadline := time.Now().Add(timeout) + var allTCP []string for { - timeout = time.Until(deadline) - if timeout <= 0 { - return nil + var gotTCP *TCP + if timeout = time.Until(deadline); timeout > 0 { + gotTCP = conn.Recv(timeout) + } + if gotTCP == nil { + return nil, fmt.Errorf("got %d packets:\n%s", len(allTCP), strings.Join(allTCP, "\n")) } - gotTCP := conn.Recv(timeout) if tcp.match(gotTCP) { - return gotTCP + return gotTCP, nil } + allTCP = append(allTCP, gotTCP.String()) } } @@ -284,10 +291,11 @@ func (conn *TCPIPv4) Handshake() { conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)}) // Wait for the SYN-ACK. - conn.SynAck = conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) - if conn.SynAck == nil { - conn.t.Fatalf("didn't get synack during handshake") + synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + if synAck == nil { + conn.t.Fatalf("didn't get synack during handshake: %s", err) } + conn.SynAck = synAck // Send an ACK. conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) @@ -412,11 +420,7 @@ func (conn *UDPIPv4) Recv(timeout time.Duration) *UDP { if b == nil { break } - layers, err := ParseEther(b) - if err != nil { - conn.t.Logf("can't parse frame: %s", err) - continue // Ignore packets that can't be parsed. - } + layers := Parse(ParseEther, b) if !conn.incoming.match(layers) { continue // Ignore packets that don't match the expected incoming. } @@ -425,21 +429,28 @@ func (conn *UDPIPv4) Recv(timeout time.Duration) *UDP { return nil } +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *UDPIPv4) Drain() { + conn.sniffer.Drain() +} + // Expect a packet that matches the provided udp within the timeout specified. // If it doesn't arrive in time, the test fails. -func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) *UDP { +func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) (*UDP, error) { deadline := time.Now().Add(timeout) + var allUDP []string for { - timeout = time.Until(deadline) - if timeout <= 0 { - return nil + var gotUDP *UDP + if timeout = time.Until(deadline); timeout > 0 { + gotUDP = conn.Recv(timeout) } - gotUDP := conn.Recv(timeout) if gotUDP == nil { - return nil + return nil, fmt.Errorf("got %d packets:\n%s", len(allUDP), strings.Join(allUDP, "\n")) } if udp.match(gotUDP) { - return gotUDP + return gotUDP, nil } + allUDP = append(allUDP, gotUDP.String()) } } diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index 4d6625941..b467c15cc 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -118,7 +118,7 @@ func stringLayer(l Layer) string { if v.IsNil() { continue } - ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v)) + ret = append(ret, fmt.Sprintf("%s:%v", t.Name, reflect.Indirect(v))) } return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " ")) } @@ -153,7 +153,7 @@ func (l *Ether) toBytes() ([]byte, error) { fields.Type = header.IPv4ProtocolNumber default: // TODO(b/150301488): Support more protocols, like IPv6. - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", n) + return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n) } } h.Encode(fields) @@ -172,27 +172,46 @@ func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocol return &v } +// LayerParser parses the input bytes and returns a Layer along with the next +// LayerParser to run. If there is no more parsing to do, the returned +// LayerParser is nil. +type LayerParser func([]byte) (Layer, LayerParser) + +// Parse parses bytes starting with the first LayerParser and using successive +// LayerParsers until all the bytes are parsed. +func Parse(parser LayerParser, b []byte) Layers { + var layers Layers + for { + var layer Layer + layer, parser = parser(b) + layers = append(layers, layer) + if parser == nil { + break + } + b = b[layer.length():] + } + layers.linkLayers() + return layers +} + // ParseEther parses the bytes assuming that they start with an ethernet header // and continues parsing further encapsulations. -func ParseEther(b []byte) (Layers, error) { +func ParseEther(b []byte) (Layer, LayerParser) { h := header.Ethernet(b) ether := Ether{ SrcAddr: LinkAddress(h.SourceAddress()), DstAddr: LinkAddress(h.DestinationAddress()), Type: NetworkProtocolNumber(h.Type()), } - layers := Layers{ðer} + var nextParser LayerParser switch h.Type() { case header.IPv4ProtocolNumber: - moreLayers, err := ParseIPv4(b[ether.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + nextParser = ParseIPv4 default: - // TODO(b/150301488): Support more protocols, like IPv6. - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %#v", b) + // Assume that the rest is a payload. + nextParser = ParsePayload } + return ðer, nextParser } func (l *Ether) match(other Layer) bool { @@ -274,7 +293,7 @@ func (l *IPv4) toBytes() ([]byte, error) { fields.Protocol = uint8(header.UDPProtocolNumber) default: // TODO(b/150301488): Support more protocols as needed. - return nil, fmt.Errorf("can't deduce the ip header's next protocol: %#v", n) + return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n) } } if l.SrcAddr != nil { @@ -313,7 +332,7 @@ func Address(v tcpip.Address) *tcpip.Address { // ParseIPv4 parses the bytes assuming that they start with an ipv4 header and // continues parsing further encapsulations. -func ParseIPv4(b []byte) (Layers, error) { +func ParseIPv4(b []byte) (Layer, LayerParser) { h := header.IPv4(b) tos, _ := h.TOS() ipv4 := IPv4{ @@ -329,22 +348,17 @@ func ParseIPv4(b []byte) (Layers, error) { SrcAddr: Address(h.SourceAddress()), DstAddr: Address(h.DestinationAddress()), } - layers := Layers{&ipv4} + var nextParser LayerParser switch h.TransportProtocol() { case header.TCPProtocolNumber: - moreLayers, err := ParseTCP(b[ipv4.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + nextParser = ParseTCP case header.UDPProtocolNumber: - moreLayers, err := ParseUDP(b[ipv4.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + nextParser = ParseUDP + default: + // Assume that the rest is a payload. + nextParser = ParsePayload } - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", h.Protocol()) + return &ipv4, nextParser } func (l *IPv4) match(other Layer) bool { @@ -470,7 +484,7 @@ func Uint32(v uint32) *uint32 { // ParseTCP parses the bytes assuming that they start with a tcp header and // continues parsing further encapsulations. -func ParseTCP(b []byte) (Layers, error) { +func ParseTCP(b []byte) (Layer, LayerParser) { h := header.TCP(b) tcp := TCP{ SrcPort: Uint16(h.SourcePort()), @@ -483,12 +497,7 @@ func ParseTCP(b []byte) (Layers, error) { Checksum: Uint16(h.Checksum()), UrgentPointer: Uint16(h.UrgentPointer()), } - layers := Layers{&tcp} - moreLayers, err := ParsePayload(b[tcp.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + return &tcp, ParsePayload } func (l *TCP) match(other Layer) bool { @@ -557,8 +566,8 @@ func setUDPChecksum(h *header.UDP, udp *UDP) error { } // ParseUDP parses the bytes assuming that they start with a udp header and -// continues parsing further encapsulations. -func ParseUDP(b []byte) (Layers, error) { +// returns the parsed layer and the next parser to use. +func ParseUDP(b []byte) (Layer, LayerParser) { h := header.UDP(b) udp := UDP{ SrcPort: Uint16(h.SourcePort()), @@ -566,12 +575,7 @@ func ParseUDP(b []byte) (Layers, error) { Length: Uint16(h.Length()), Checksum: Uint16(h.Checksum()), } - layers := Layers{&udp} - moreLayers, err := ParsePayload(b[udp.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + return &udp, ParsePayload } func (l *UDP) match(other Layer) bool { @@ -603,11 +607,11 @@ func (l *Payload) String() string { // ParsePayload parses the bytes assuming that they start with a payload and // continue to the end. There can be no further encapsulations. -func ParsePayload(b []byte) (Layers, error) { +func ParsePayload(b []byte) (Layer, LayerParser) { payload := Payload{ Bytes: b, } - return Layers{&payload}, nil + return &payload, nil } func (l *Payload) toBytes() ([]byte, error) { @@ -625,15 +629,24 @@ func (l *Payload) length() int { // Layers is an array of Layer and supports similar functions to Layer. type Layers []Layer -func (ls *Layers) toBytes() ([]byte, error) { +// linkLayers sets the linked-list ponters in ls. +func (ls *Layers) linkLayers() { for i, l := range *ls { if i > 0 { l.setPrev((*ls)[i-1]) + } else { + l.setPrev(nil) } if i+1 < len(*ls) { l.setNext((*ls)[i+1]) + } else { + l.setNext(nil) } } +} + +func (ls *Layers) toBytes() ([]byte, error) { + ls.linkLayers() outBytes := []byte{} for _, l := range *ls { layerBytes, err := l.toBytes() diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go index b39839625..8ffc26bf9 100644 --- a/test/packetimpact/testbench/layers_test.go +++ b/test/packetimpact/testbench/layers_test.go @@ -16,6 +16,8 @@ package testbench import "testing" +import "gvisor.dev/gvisor/pkg/tcpip" + func TestLayerMatch(t *testing.T) { var nilPayload *Payload noPayload := &Payload{} @@ -47,3 +49,96 @@ func TestLayerMatch(t *testing.T) { } } } + +func TestLayerStringFormat(t *testing.T) { + for _, tt := range []struct { + name string + l Layer + want string + }{ + { + name: "TCP", + l: &TCP{ + SrcPort: Uint16(34785), + DstPort: Uint16(47767), + SeqNum: Uint32(3452155723), + AckNum: Uint32(2596996163), + DataOffset: Uint8(5), + Flags: Uint8(20), + WindowSize: Uint16(64240), + Checksum: Uint16(0x2e2b), + }, + want: "&testbench.TCP{" + + "SrcPort:34785 " + + "DstPort:47767 " + + "SeqNum:3452155723 " + + "AckNum:2596996163 " + + "DataOffset:5 " + + "Flags:20 " + + "WindowSize:64240 " + + "Checksum:11819" + + "}", + }, + { + name: "UDP", + l: &UDP{ + SrcPort: Uint16(34785), + DstPort: Uint16(47767), + Length: Uint16(12), + }, + want: "&testbench.UDP{" + + "SrcPort:34785 " + + "DstPort:47767 " + + "Length:12" + + "}", + }, + { + name: "IPv4", + l: &IPv4{ + IHL: Uint8(5), + TOS: Uint8(0), + TotalLength: Uint16(44), + ID: Uint16(0), + Flags: Uint8(2), + FragmentOffset: Uint16(0), + TTL: Uint8(64), + Protocol: Uint8(6), + Checksum: Uint16(0x2e2b), + SrcAddr: Address(tcpip.Address([]byte{197, 34, 63, 10})), + DstAddr: Address(tcpip.Address([]byte{197, 34, 63, 20})), + }, + want: "&testbench.IPv4{" + + "IHL:5 " + + "TOS:0 " + + "TotalLength:44 " + + "ID:0 " + + "Flags:2 " + + "FragmentOffset:0 " + + "TTL:64 " + + "Protocol:6 " + + "Checksum:11819 " + + "SrcAddr:197.34.63.10 " + + "DstAddr:197.34.63.20" + + "}", + }, + { + name: "Ether", + l: &Ether{ + SrcAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x0a})), + DstAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x14})), + Type: NetworkProtocolNumber(4), + }, + want: "&testbench.Ether{" + + "SrcAddr:02:42:c5:22:3f:0a " + + "DstAddr:02:42:c5:22:3f:14 " + + "Type:4" + + "}", + }, + } { + t.Run(tt.name, func(t *testing.T) { + if got := tt.l.String(); got != tt.want { + t.Errorf("%s.String() = %s, want: %s", tt.name, got, tt.want) + } + }) + } +} diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go index 0074484f7..09bfa43c5 100644 --- a/test/packetimpact/testbench/rawsockets.go +++ b/test/packetimpact/testbench/rawsockets.go @@ -97,6 +97,29 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte { } } +// Drain drains the Sniffer's socket receive buffer by receiving until there's +// nothing else to receive. +func (s *Sniffer) Drain() { + s.t.Helper() + flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0) + if err != nil { + s.t.Fatalf("failed to get sniffer socket fd flags: %s", err) + } + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags|unix.O_NONBLOCK); err != nil { + s.t.Fatalf("failed to make sniffer socket non-blocking: %s", err) + } + for { + buf := make([]byte, maxReadSize) + _, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC) + if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK { + break + } + } + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil { + s.t.Fatalf("failed to restore sniffer socket fd flags: %s", err) + } +} + // Close the socket that Sniffer is using. func (s *Sniffer) Close() { if err := unix.Close(s.fd); err != nil { |