summaryrefslogtreecommitdiffhomepage
path: root/test/packetimpact/testbench
diff options
context:
space:
mode:
Diffstat (limited to 'test/packetimpact/testbench')
-rw-r--r--test/packetimpact/testbench/BUILD2
-rw-r--r--test/packetimpact/testbench/connections.go54
-rw-r--r--test/packetimpact/testbench/dut.go37
-rw-r--r--test/packetimpact/testbench/layers.go71
-rw-r--r--test/packetimpact/testbench/layers_test.go2
-rw-r--r--test/packetimpact/testbench/rawsockets.go20
-rw-r--r--test/packetimpact/testbench/testbench.go23
7 files changed, 141 insertions, 68 deletions
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
index 616215dc3..d8059ab98 100644
--- a/test/packetimpact/testbench/BUILD
+++ b/test/packetimpact/testbench/BUILD
@@ -16,6 +16,8 @@ go_library(
],
visibility = ["//test/packetimpact:__subpackages__"],
deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
"//pkg/hostarch",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index 8ad9040ff..ed56f9ac7 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -594,32 +594,50 @@ func (conn *Connection) Expect(t *testing.T, layer Layer, timeout time.Duration)
func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Duration) (Layers, error) {
t.Helper()
- deadline := time.Now().Add(timeout)
+ frames, ok := conn.ListenForFrame(t, layers, timeout)
+ if ok {
+ return frames[len(frames)-1], nil
+ }
+ if len(frames) == 0 {
+ return nil, fmt.Errorf("got no frames matching %s during %s", layers, timeout)
+ }
+
var errs error
+ for _, got := range frames {
+ want := conn.incoming(layers)
+ if err := want.merge(layers); err != nil {
+ errs = multierr.Combine(errs, err)
+ } else {
+ errs = multierr.Combine(errs, &layersError{got: got, want: want})
+ }
+ }
+ return nil, fmt.Errorf("got frames:\n%w want %s during %s", errs, layers, timeout)
+}
+
+// ListenForFrame captures all frames until a frame matches the provided Layers,
+// or until the timeout specified. Returns all captured frames, including the
+// matched frame, and true if the desired frame was found.
+func (conn *Connection) ListenForFrame(t *testing.T, layers Layers, timeout time.Duration) ([]Layers, bool) {
+ t.Helper()
+
+ deadline := time.Now().Add(timeout)
+ var frames []Layers
for {
- var gotLayers Layers
+ var got Layers
if timeout := time.Until(deadline); timeout > 0 {
- gotLayers = conn.recvFrame(t, timeout)
+ got = conn.recvFrame(t, timeout)
}
- if gotLayers == nil {
- if errs == nil {
- return nil, fmt.Errorf("got no frames matching %s during %s", layers, timeout)
- }
- return nil, fmt.Errorf("got frames:\n%w want %s during %s", errs, layers, timeout)
+ if got == nil {
+ return frames, false
}
- if conn.match(layers, gotLayers) {
+ frames = append(frames, got)
+ if conn.match(layers, got) {
for i, s := range conn.layerStates {
- if err := s.received(gotLayers[i]); err != nil {
+ if err := s.received(got[i]); err != nil {
t.Fatalf("failed to update test connection's layer states based on received frame: %s", err)
}
}
- return gotLayers, nil
- }
- want := conn.incoming(layers)
- if err := want.merge(layers); err != nil {
- errs = multierr.Combine(errs, err)
- } else {
- errs = multierr.Combine(errs, &layersError{got: gotLayers, want: want})
+ return frames, true
}
}
}
@@ -1025,6 +1043,8 @@ func (conn *UDPIPv4) SendIP(t *testing.T, ip IPv4, udp UDP, additionalLayers ...
// SendFrame sends a frame on the wire and updates the state of all layers.
func (conn *UDPIPv4) SendFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) {
+ t.Helper()
+
conn.send(t, overrideLayers, additionalLayers...)
}
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
index eabdc8cb3..0cac0bf1b 100644
--- a/test/packetimpact/testbench/dut.go
+++ b/test/packetimpact/testbench/dut.go
@@ -22,11 +22,13 @@ import (
"testing"
"time"
- pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto"
-
"golang.org/x/sys/unix"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ bin "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/hostarch"
+ pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto"
)
// DUT communicates with the DUT to force it to make POSIX calls.
@@ -428,6 +430,33 @@ func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, so
return ret, timeval, errno
}
+// GetSockOptTCPInfo retreives TCPInfo for the given socket descriptor.
+func (dut *DUT) GetSockOptTCPInfo(t *testing.T, sockfd int32) linux.TCPInfo {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, info, err := dut.GetSockOptTCPInfoWithErrno(ctx, t, sockfd)
+ if ret != 0 || err != unix.Errno(0) {
+ t.Fatalf("failed to GetSockOptTCPInfo: %s", err)
+ }
+ return info
+}
+
+// GetSockOptTCPInfoWithErrno retreives TCPInfo with any errno.
+func (dut *DUT) GetSockOptTCPInfoWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, linux.TCPInfo, error) {
+ t.Helper()
+
+ info := linux.TCPInfo{}
+ ret, infoBytes, errno := dut.GetSockOptWithErrno(ctx, t, sockfd, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo))
+ if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want {
+ t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want)
+ }
+ bin.Unmarshal(infoBytes, hostarch.ByteOrder, &info)
+
+ return ret, info, errno
+}
+
// Listen calls listen on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// ListenWithErrno.
@@ -469,8 +498,8 @@ func (dut *DUT) PollOne(t *testing.T, fd int32, events int16, timeout time.Durat
if readyFd := pfds[0].Fd; readyFd != fd {
t.Fatalf("Poll returned an fd %d that was not requested (%d)", readyFd, fd)
}
- if got, want := pfds[0].Revents, int16(events); got&want == 0 {
- t.Fatalf("Poll returned no events in our interest, got: %#b, want: %#b", got, want)
+ if got, want := pfds[0].Revents, int16(events); got&want != want {
+ t.Fatalf("Poll returned events does not include all of the interested events, got: %#b, want: %#b", got, want)
}
}
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index 2311f7686..2644b3248 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -357,7 +357,7 @@ func (l *IPv4) ToBytes() ([]byte, error) {
case *ICMPv4:
fields.Protocol = uint8(header.ICMPv4ProtocolNumber)
default:
- // TODO(b/150301488): Support more protocols as needed.
+ // We can add support for more protocols as needed.
return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n)
}
}
@@ -824,6 +824,8 @@ type ICMPv6 struct {
Type *header.ICMPv6Type
Code *header.ICMPv6Code
Checksum *uint16
+ Ident *uint16 // Only in Echo Request/Reply.
+ Pointer *uint32 // Only in Parameter Problem.
Payload []byte
}
@@ -835,7 +837,7 @@ func (l *ICMPv6) String() string {
// ToBytes implements Layer.ToBytes.
func (l *ICMPv6) ToBytes() ([]byte, error) {
- b := make([]byte, header.ICMPv6HeaderSize+len(l.Payload))
+ b := make([]byte, header.ICMPv6MinimumSize+len(l.Payload))
h := header.ICMPv6(b)
if l.Type != nil {
h.SetType(*l.Type)
@@ -843,27 +845,34 @@ func (l *ICMPv6) ToBytes() ([]byte, error) {
if l.Code != nil {
h.SetCode(*l.Code)
}
- if n := copy(h.MessageBody(), l.Payload); n != len(l.Payload) {
+ if n := copy(h.Payload(), l.Payload); n != len(l.Payload) {
panic(fmt.Sprintf("copied %d bytes, expected to copy %d bytes", n, len(l.Payload)))
}
+ typ := h.Type()
+ switch typ {
+ case header.ICMPv6EchoRequest, header.ICMPv6EchoReply:
+ if l.Ident != nil {
+ h.SetIdent(*l.Ident)
+ }
+ case header.ICMPv6ParamProblem:
+ if l.Pointer != nil {
+ h.SetTypeSpecific(*l.Pointer)
+ }
+ }
if l.Checksum != nil {
h.SetChecksum(*l.Checksum)
} else {
// It is possible that the ICMPv6 header does not follow the IPv6 header
// immediately, there could be one or more extension headers in between.
- // We need to search forward to find the IPv6 header.
- for prev := l.Prev(); prev != nil; prev = prev.Prev() {
- if ipv6, ok := prev.(*IPv6); ok {
- payload, err := payload(l)
- if err != nil {
- return nil, err
- }
+ // We need to search backwards to find the IPv6 header.
+ for layer := l.Prev(); layer != nil; layer = layer.Prev() {
+ if ipv6, ok := layer.(*IPv6); ok {
h.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: h,
+ Header: h[:header.ICMPv6PayloadOffset],
Src: *ipv6.SrcAddr,
Dst: *ipv6.DstAddr,
- PayloadCsum: header.ChecksumVV(payload, 0 /* initial */),
- PayloadLen: payload.Size(),
+ PayloadCsum: header.Checksum(l.Payload, 0 /* initial */),
+ PayloadLen: len(l.Payload),
}))
break
}
@@ -884,20 +893,21 @@ func ICMPv6Code(v header.ICMPv6Code) *header.ICMPv6Code {
return &v
}
-// Byte is a helper routine that allocates a new byte value to store
-// v and returns a pointer to it.
-func Byte(v byte) *byte {
- return &v
-}
-
// parseICMPv6 parses the bytes assuming that they start with an ICMPv6 header.
func parseICMPv6(b []byte) (Layer, layerParser) {
h := header.ICMPv6(b)
+ msgType := h.Type()
icmpv6 := ICMPv6{
- Type: ICMPv6Type(h.Type()),
+ Type: ICMPv6Type(msgType),
Code: ICMPv6Code(h.Code()),
Checksum: Uint16(h.Checksum()),
- Payload: h.MessageBody(),
+ Payload: h.Payload(),
+ }
+ switch msgType {
+ case header.ICMPv6EchoRequest, header.ICMPv6EchoReply:
+ icmpv6.Ident = Uint16(h.Ident())
+ case header.ICMPv6ParamProblem:
+ icmpv6.Pointer = Uint32(h.TypeSpecific())
}
return &icmpv6, nil
}
@@ -907,7 +917,7 @@ func (l *ICMPv6) match(other Layer) bool {
}
func (l *ICMPv6) length() int {
- return header.ICMPv6HeaderSize + len(l.Payload)
+ return header.ICMPv6MinimumSize + len(l.Payload)
}
// merge overrides the values in l with the values from other but only in fields
@@ -954,8 +964,8 @@ func (l *ICMPv4) ToBytes() ([]byte, error) {
if l.Code != nil {
h.SetCode(*l.Code)
}
- if copied := copy(h.Payload(), l.Payload); copied != len(l.Payload) {
- panic(fmt.Sprintf("wrong number of bytes copied into h.Payload(): got = %d, want = %d", len(h.Payload()), len(l.Payload)))
+ if n := copy(h.Payload(), l.Payload); n != len(l.Payload) {
+ panic(fmt.Sprintf("wrong number of bytes copied into h.Payload(): got = %d, want = %d", n, len(l.Payload)))
}
typ := h.Type()
switch typ {
@@ -977,16 +987,7 @@ func (l *ICMPv4) ToBytes() ([]byte, error) {
if l.Checksum != nil {
h.SetChecksum(*l.Checksum)
} else {
- // Compute the checksum based on the ICMPv4.Payload and also the subsequent
- // layers.
- payload, err := payload(l)
- if err != nil {
- return nil, err
- }
- var vv buffer.VectorisedView
- vv.AppendView(buffer.View(l.Payload))
- vv.Append(payload)
- h.SetChecksum(header.ICMPv4Checksum(h, header.ChecksumVV(vv, 0 /* initial */)))
+ h.SetChecksum(^header.Checksum(h, 0))
}
return h, nil
@@ -1019,7 +1020,7 @@ func (l *ICMPv4) match(other Layer) bool {
}
func (l *ICMPv4) length() int {
- return header.ICMPv4MinimumSize
+ return header.ICMPv4MinimumSize + len(l.Payload)
}
// merge overrides the values in l with the values from other but only in fields
diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go
index 614a5de1e..bc96e0c88 100644
--- a/test/packetimpact/testbench/layers_test.go
+++ b/test/packetimpact/testbench/layers_test.go
@@ -596,7 +596,7 @@ func TestIPv6ExtHdrOptions(t *testing.T) {
Type: ICMPv6Type(header.ICMPv6ParamProblem),
Code: ICMPv6Code(header.ICMPv6ErroneousHeader),
Checksum: Uint16(0x5f98),
- Payload: []byte{0x00, 0x00, 0x00, 0x06},
+ Pointer: Uint32(6),
},
},
},
diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go
index feeb0888a..6d95c033d 100644
--- a/test/packetimpact/testbench/rawsockets.go
+++ b/test/packetimpact/testbench/rawsockets.go
@@ -17,7 +17,6 @@ package testbench
import (
"encoding/binary"
"fmt"
- "math"
"net"
"testing"
"time"
@@ -81,19 +80,20 @@ func (s *Sniffer) Recv(t *testing.T, timeout time.Duration) []byte {
deadline := time.Now().Add(timeout)
for {
- timeout = deadline.Sub(time.Now())
+ timeout = time.Until(deadline)
if timeout <= 0 {
return nil
}
- whole, frac := math.Modf(timeout.Seconds())
- tv := unix.Timeval{
- Sec: int64(whole),
- Usec: int64(frac * float64(time.Second/time.Microsecond)),
+ usec := timeout.Microseconds()
+ if usec == 0 {
+ // Timeout is less than a microsecond; set usec to 1 to avoid
+ // blocking indefinitely.
+ usec = 1
}
- // The following should never happen, but having this guard here is better
- // than blocking indefinitely in the future.
- if tv.Sec == 0 && tv.Usec == 0 {
- t.Fatal("setting SO_RCVTIMEO to 0 means blocking indefinitely")
+ const microsInOne = 1e6
+ tv := unix.Timeval{
+ Sec: usec / microsInOne,
+ Usec: usec % microsInOne,
}
if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err)
diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go
index 37d02365a..caa389780 100644
--- a/test/packetimpact/testbench/testbench.go
+++ b/test/packetimpact/testbench/testbench.go
@@ -57,11 +57,21 @@ type DUTUname struct {
OperatingSystem string
}
-// IsLinux returns true if we are running natively on Linux.
+// IsLinux returns true if the DUT is running Linux.
func (n *DUTUname) IsLinux() bool {
return Native && n.OperatingSystem == "GNU/Linux"
}
+// IsGvisor returns true if the DUT is running gVisor.
+func (*DUTUname) IsGvisor() bool {
+ return !Native
+}
+
+// IsFuchsia returns true if the DUT is running Fuchsia.
+func (n *DUTUname) IsFuchsia() bool {
+ return Native && n.OperatingSystem == "Fuchsia"
+}
+
// DUTTestNet describes the test network setup on dut and how the testbench
// should connect with an existing DUT.
type DUTTestNet struct {
@@ -99,6 +109,16 @@ type DUTTestNet struct {
POSIXServerPort uint16
}
+// SubnetBroadcast returns the test network's subnet broadcast address.
+func (n *DUTTestNet) SubnetBroadcast() net.IP {
+ addr := append([]byte(nil), n.RemoteIPv4...)
+ mask := net.CIDRMask(n.IPv4PrefixLength, net.IPv4len*8)
+ for i := range addr {
+ addr[i] |= ^mask[i]
+ }
+ return addr
+}
+
// registerFlags defines flags and associates them with the package-level
// exported variables above. It should be called by tests in their init
// functions.
@@ -112,6 +132,7 @@ func registerFlags(fs *flag.FlagSet) {
// Initialize initializes the testbench, it parse the flags and sets up the
// pool of test networks for testbench's later use.
func Initialize(fs *flag.FlagSet) {
+ testing.Init()
registerFlags(fs)
flag.Parse()
if err := loadDUTInfos(); err != nil {