diff options
Diffstat (limited to 'test/packetimpact/testbench')
-rw-r--r-- | test/packetimpact/testbench/BUILD | 2 | ||||
-rw-r--r-- | test/packetimpact/testbench/connections.go | 54 | ||||
-rw-r--r-- | test/packetimpact/testbench/dut.go | 37 | ||||
-rw-r--r-- | test/packetimpact/testbench/layers.go | 71 | ||||
-rw-r--r-- | test/packetimpact/testbench/layers_test.go | 2 | ||||
-rw-r--r-- | test/packetimpact/testbench/rawsockets.go | 20 | ||||
-rw-r--r-- | test/packetimpact/testbench/testbench.go | 23 |
7 files changed, 141 insertions, 68 deletions
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD index 616215dc3..d8059ab98 100644 --- a/test/packetimpact/testbench/BUILD +++ b/test/packetimpact/testbench/BUILD @@ -16,6 +16,8 @@ go_library( ], visibility = ["//test/packetimpact:__subpackages__"], deps = [ + "//pkg/abi/linux", + "//pkg/binary", "//pkg/hostarch", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index 8ad9040ff..ed56f9ac7 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -594,32 +594,50 @@ func (conn *Connection) Expect(t *testing.T, layer Layer, timeout time.Duration) func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Duration) (Layers, error) { t.Helper() - deadline := time.Now().Add(timeout) + frames, ok := conn.ListenForFrame(t, layers, timeout) + if ok { + return frames[len(frames)-1], nil + } + if len(frames) == 0 { + return nil, fmt.Errorf("got no frames matching %s during %s", layers, timeout) + } + var errs error + for _, got := range frames { + want := conn.incoming(layers) + if err := want.merge(layers); err != nil { + errs = multierr.Combine(errs, err) + } else { + errs = multierr.Combine(errs, &layersError{got: got, want: want}) + } + } + return nil, fmt.Errorf("got frames:\n%w want %s during %s", errs, layers, timeout) +} + +// ListenForFrame captures all frames until a frame matches the provided Layers, +// or until the timeout specified. Returns all captured frames, including the +// matched frame, and true if the desired frame was found. +func (conn *Connection) ListenForFrame(t *testing.T, layers Layers, timeout time.Duration) ([]Layers, bool) { + t.Helper() + + deadline := time.Now().Add(timeout) + var frames []Layers for { - var gotLayers Layers + var got Layers if timeout := time.Until(deadline); timeout > 0 { - gotLayers = conn.recvFrame(t, timeout) + got = conn.recvFrame(t, timeout) } - if gotLayers == nil { - if errs == nil { - return nil, fmt.Errorf("got no frames matching %s during %s", layers, timeout) - } - return nil, fmt.Errorf("got frames:\n%w want %s during %s", errs, layers, timeout) + if got == nil { + return frames, false } - if conn.match(layers, gotLayers) { + frames = append(frames, got) + if conn.match(layers, got) { for i, s := range conn.layerStates { - if err := s.received(gotLayers[i]); err != nil { + if err := s.received(got[i]); err != nil { t.Fatalf("failed to update test connection's layer states based on received frame: %s", err) } } - return gotLayers, nil - } - want := conn.incoming(layers) - if err := want.merge(layers); err != nil { - errs = multierr.Combine(errs, err) - } else { - errs = multierr.Combine(errs, &layersError{got: gotLayers, want: want}) + return frames, true } } } @@ -1025,6 +1043,8 @@ func (conn *UDPIPv4) SendIP(t *testing.T, ip IPv4, udp UDP, additionalLayers ... // SendFrame sends a frame on the wire and updates the state of all layers. func (conn *UDPIPv4) SendFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) { + t.Helper() + conn.send(t, overrideLayers, additionalLayers...) } diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index eabdc8cb3..0cac0bf1b 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -22,11 +22,13 @@ import ( "testing" "time" - pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" - "golang.org/x/sys/unix" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "gvisor.dev/gvisor/pkg/abi/linux" + bin "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/hostarch" + pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" ) // DUT communicates with the DUT to force it to make POSIX calls. @@ -428,6 +430,33 @@ func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, so return ret, timeval, errno } +// GetSockOptTCPInfo retreives TCPInfo for the given socket descriptor. +func (dut *DUT) GetSockOptTCPInfo(t *testing.T, sockfd int32) linux.TCPInfo { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, info, err := dut.GetSockOptTCPInfoWithErrno(ctx, t, sockfd) + if ret != 0 || err != unix.Errno(0) { + t.Fatalf("failed to GetSockOptTCPInfo: %s", err) + } + return info +} + +// GetSockOptTCPInfoWithErrno retreives TCPInfo with any errno. +func (dut *DUT) GetSockOptTCPInfoWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, linux.TCPInfo, error) { + t.Helper() + + info := linux.TCPInfo{} + ret, infoBytes, errno := dut.GetSockOptWithErrno(ctx, t, sockfd, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) + if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { + t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) + } + bin.Unmarshal(infoBytes, hostarch.ByteOrder, &info) + + return ret, info, errno +} + // Listen calls listen on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // ListenWithErrno. @@ -469,8 +498,8 @@ func (dut *DUT) PollOne(t *testing.T, fd int32, events int16, timeout time.Durat if readyFd := pfds[0].Fd; readyFd != fd { t.Fatalf("Poll returned an fd %d that was not requested (%d)", readyFd, fd) } - if got, want := pfds[0].Revents, int16(events); got&want == 0 { - t.Fatalf("Poll returned no events in our interest, got: %#b, want: %#b", got, want) + if got, want := pfds[0].Revents, int16(events); got&want != want { + t.Fatalf("Poll returned events does not include all of the interested events, got: %#b, want: %#b", got, want) } } diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index 2311f7686..2644b3248 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -357,7 +357,7 @@ func (l *IPv4) ToBytes() ([]byte, error) { case *ICMPv4: fields.Protocol = uint8(header.ICMPv4ProtocolNumber) default: - // TODO(b/150301488): Support more protocols as needed. + // We can add support for more protocols as needed. return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n) } } @@ -824,6 +824,8 @@ type ICMPv6 struct { Type *header.ICMPv6Type Code *header.ICMPv6Code Checksum *uint16 + Ident *uint16 // Only in Echo Request/Reply. + Pointer *uint32 // Only in Parameter Problem. Payload []byte } @@ -835,7 +837,7 @@ func (l *ICMPv6) String() string { // ToBytes implements Layer.ToBytes. func (l *ICMPv6) ToBytes() ([]byte, error) { - b := make([]byte, header.ICMPv6HeaderSize+len(l.Payload)) + b := make([]byte, header.ICMPv6MinimumSize+len(l.Payload)) h := header.ICMPv6(b) if l.Type != nil { h.SetType(*l.Type) @@ -843,27 +845,34 @@ func (l *ICMPv6) ToBytes() ([]byte, error) { if l.Code != nil { h.SetCode(*l.Code) } - if n := copy(h.MessageBody(), l.Payload); n != len(l.Payload) { + if n := copy(h.Payload(), l.Payload); n != len(l.Payload) { panic(fmt.Sprintf("copied %d bytes, expected to copy %d bytes", n, len(l.Payload))) } + typ := h.Type() + switch typ { + case header.ICMPv6EchoRequest, header.ICMPv6EchoReply: + if l.Ident != nil { + h.SetIdent(*l.Ident) + } + case header.ICMPv6ParamProblem: + if l.Pointer != nil { + h.SetTypeSpecific(*l.Pointer) + } + } if l.Checksum != nil { h.SetChecksum(*l.Checksum) } else { // It is possible that the ICMPv6 header does not follow the IPv6 header // immediately, there could be one or more extension headers in between. - // We need to search forward to find the IPv6 header. - for prev := l.Prev(); prev != nil; prev = prev.Prev() { - if ipv6, ok := prev.(*IPv6); ok { - payload, err := payload(l) - if err != nil { - return nil, err - } + // We need to search backwards to find the IPv6 header. + for layer := l.Prev(); layer != nil; layer = layer.Prev() { + if ipv6, ok := layer.(*IPv6); ok { h.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: h, + Header: h[:header.ICMPv6PayloadOffset], Src: *ipv6.SrcAddr, Dst: *ipv6.DstAddr, - PayloadCsum: header.ChecksumVV(payload, 0 /* initial */), - PayloadLen: payload.Size(), + PayloadCsum: header.Checksum(l.Payload, 0 /* initial */), + PayloadLen: len(l.Payload), })) break } @@ -884,20 +893,21 @@ func ICMPv6Code(v header.ICMPv6Code) *header.ICMPv6Code { return &v } -// Byte is a helper routine that allocates a new byte value to store -// v and returns a pointer to it. -func Byte(v byte) *byte { - return &v -} - // parseICMPv6 parses the bytes assuming that they start with an ICMPv6 header. func parseICMPv6(b []byte) (Layer, layerParser) { h := header.ICMPv6(b) + msgType := h.Type() icmpv6 := ICMPv6{ - Type: ICMPv6Type(h.Type()), + Type: ICMPv6Type(msgType), Code: ICMPv6Code(h.Code()), Checksum: Uint16(h.Checksum()), - Payload: h.MessageBody(), + Payload: h.Payload(), + } + switch msgType { + case header.ICMPv6EchoRequest, header.ICMPv6EchoReply: + icmpv6.Ident = Uint16(h.Ident()) + case header.ICMPv6ParamProblem: + icmpv6.Pointer = Uint32(h.TypeSpecific()) } return &icmpv6, nil } @@ -907,7 +917,7 @@ func (l *ICMPv6) match(other Layer) bool { } func (l *ICMPv6) length() int { - return header.ICMPv6HeaderSize + len(l.Payload) + return header.ICMPv6MinimumSize + len(l.Payload) } // merge overrides the values in l with the values from other but only in fields @@ -954,8 +964,8 @@ func (l *ICMPv4) ToBytes() ([]byte, error) { if l.Code != nil { h.SetCode(*l.Code) } - if copied := copy(h.Payload(), l.Payload); copied != len(l.Payload) { - panic(fmt.Sprintf("wrong number of bytes copied into h.Payload(): got = %d, want = %d", len(h.Payload()), len(l.Payload))) + if n := copy(h.Payload(), l.Payload); n != len(l.Payload) { + panic(fmt.Sprintf("wrong number of bytes copied into h.Payload(): got = %d, want = %d", n, len(l.Payload))) } typ := h.Type() switch typ { @@ -977,16 +987,7 @@ func (l *ICMPv4) ToBytes() ([]byte, error) { if l.Checksum != nil { h.SetChecksum(*l.Checksum) } else { - // Compute the checksum based on the ICMPv4.Payload and also the subsequent - // layers. - payload, err := payload(l) - if err != nil { - return nil, err - } - var vv buffer.VectorisedView - vv.AppendView(buffer.View(l.Payload)) - vv.Append(payload) - h.SetChecksum(header.ICMPv4Checksum(h, header.ChecksumVV(vv, 0 /* initial */))) + h.SetChecksum(^header.Checksum(h, 0)) } return h, nil @@ -1019,7 +1020,7 @@ func (l *ICMPv4) match(other Layer) bool { } func (l *ICMPv4) length() int { - return header.ICMPv4MinimumSize + return header.ICMPv4MinimumSize + len(l.Payload) } // merge overrides the values in l with the values from other but only in fields diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go index 614a5de1e..bc96e0c88 100644 --- a/test/packetimpact/testbench/layers_test.go +++ b/test/packetimpact/testbench/layers_test.go @@ -596,7 +596,7 @@ func TestIPv6ExtHdrOptions(t *testing.T) { Type: ICMPv6Type(header.ICMPv6ParamProblem), Code: ICMPv6Code(header.ICMPv6ErroneousHeader), Checksum: Uint16(0x5f98), - Payload: []byte{0x00, 0x00, 0x00, 0x06}, + Pointer: Uint32(6), }, }, }, diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go index feeb0888a..6d95c033d 100644 --- a/test/packetimpact/testbench/rawsockets.go +++ b/test/packetimpact/testbench/rawsockets.go @@ -17,7 +17,6 @@ package testbench import ( "encoding/binary" "fmt" - "math" "net" "testing" "time" @@ -81,19 +80,20 @@ func (s *Sniffer) Recv(t *testing.T, timeout time.Duration) []byte { deadline := time.Now().Add(timeout) for { - timeout = deadline.Sub(time.Now()) + timeout = time.Until(deadline) if timeout <= 0 { return nil } - whole, frac := math.Modf(timeout.Seconds()) - tv := unix.Timeval{ - Sec: int64(whole), - Usec: int64(frac * float64(time.Second/time.Microsecond)), + usec := timeout.Microseconds() + if usec == 0 { + // Timeout is less than a microsecond; set usec to 1 to avoid + // blocking indefinitely. + usec = 1 } - // The following should never happen, but having this guard here is better - // than blocking indefinitely in the future. - if tv.Sec == 0 && tv.Usec == 0 { - t.Fatal("setting SO_RCVTIMEO to 0 means blocking indefinitely") + const microsInOne = 1e6 + tv := unix.Timeval{ + Sec: usec / microsInOne, + Usec: usec % microsInOne, } if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil { t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err) diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go index 37d02365a..caa389780 100644 --- a/test/packetimpact/testbench/testbench.go +++ b/test/packetimpact/testbench/testbench.go @@ -57,11 +57,21 @@ type DUTUname struct { OperatingSystem string } -// IsLinux returns true if we are running natively on Linux. +// IsLinux returns true if the DUT is running Linux. func (n *DUTUname) IsLinux() bool { return Native && n.OperatingSystem == "GNU/Linux" } +// IsGvisor returns true if the DUT is running gVisor. +func (*DUTUname) IsGvisor() bool { + return !Native +} + +// IsFuchsia returns true if the DUT is running Fuchsia. +func (n *DUTUname) IsFuchsia() bool { + return Native && n.OperatingSystem == "Fuchsia" +} + // DUTTestNet describes the test network setup on dut and how the testbench // should connect with an existing DUT. type DUTTestNet struct { @@ -99,6 +109,16 @@ type DUTTestNet struct { POSIXServerPort uint16 } +// SubnetBroadcast returns the test network's subnet broadcast address. +func (n *DUTTestNet) SubnetBroadcast() net.IP { + addr := append([]byte(nil), n.RemoteIPv4...) + mask := net.CIDRMask(n.IPv4PrefixLength, net.IPv4len*8) + for i := range addr { + addr[i] |= ^mask[i] + } + return addr +} + // registerFlags defines flags and associates them with the package-level // exported variables above. It should be called by tests in their init // functions. @@ -112,6 +132,7 @@ func registerFlags(fs *flag.FlagSet) { // Initialize initializes the testbench, it parse the flags and sets up the // pool of test networks for testbench's later use. func Initialize(fs *flag.FlagSet) { + testing.Init() registerFlags(fs) flag.Parse() if err := loadDUTInfos(); err != nil { |