summaryrefslogtreecommitdiffhomepage
path: root/test/packetimpact/testbench
diff options
context:
space:
mode:
Diffstat (limited to 'test/packetimpact/testbench')
-rw-r--r--test/packetimpact/testbench/BUILD6
-rw-r--r--test/packetimpact/testbench/connections.go40
-rw-r--r--test/packetimpact/testbench/dut.go2
-rw-r--r--test/packetimpact/testbench/layers.go44
-rw-r--r--test/packetimpact/testbench/layers_test.go159
5 files changed, 223 insertions, 28 deletions
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
index b6a254882..3ceceb9d7 100644
--- a/test/packetimpact/testbench/BUILD
+++ b/test/packetimpact/testbench/BUILD
@@ -23,7 +23,6 @@ go_library(
"//test/packetimpact/proto:posix_server_go_proto",
"@com_github_google_go-cmp//cmp:go_default_library",
"@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
- "@com_github_imdario_mergo//:go_default_library",
"@com_github_mohae_deepcopy//:go_default_library",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//keepalive:go_default_library",
@@ -37,5 +36,8 @@ go_test(
size = "small",
srcs = ["layers_test.go"],
library = ":testbench",
- deps = ["//pkg/tcpip"],
+ deps = [
+ "//pkg/tcpip",
+ "@com_github_mohae_deepcopy//:go_default_library",
+ ],
)
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index f84fd8ba7..952a717e0 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -72,7 +72,8 @@ type layerState interface {
// incoming creates an expected Layer for comparing against a received Layer.
// Because the expectation can depend on values in the received Layer, it is
// an input to incoming. For example, the ACK number needs to be checked in a
- // TCP packet but only if the ACK flag is set in the received packet.
+ // TCP packet but only if the ACK flag is set in the received packet. The
+ // calles takes ownership of the returned Layer.
incoming(received Layer) Layer
// sent updates the layerState based on the Layer that was sent. The input is
@@ -124,6 +125,7 @@ func (s *etherState) outgoing() Layer {
return &s.out
}
+// incoming implements layerState.incoming.
func (s *etherState) incoming(Layer) Layer {
return deepcopy.Copy(&s.in).(Layer)
}
@@ -168,6 +170,7 @@ func (s *ipv4State) outgoing() Layer {
return &s.out
}
+// incoming implements layerState.incoming.
func (s *ipv4State) incoming(Layer) Layer {
return deepcopy.Copy(&s.in).(Layer)
}
@@ -234,6 +237,7 @@ func (s *tcpState) outgoing() Layer {
return &newOutgoing
}
+// incoming implements layerState.incoming.
func (s *tcpState) incoming(received Layer) Layer {
tcpReceived, ok := received.(*TCP)
if !ok {
@@ -328,6 +332,7 @@ func (s *udpState) outgoing() Layer {
return &s.out
}
+// incoming implements layerState.incoming.
func (s *udpState) incoming(Layer) Layer {
return deepcopy.Copy(&s.in).(Layer)
}
@@ -363,16 +368,33 @@ type Connection struct {
// reverse is never a match. override overrides the default matchers for each
// Layer.
func (conn *Connection) match(override, received Layers) bool {
- if len(received) < len(conn.layerStates) {
+ var layersToMatch int
+ if len(override) < len(conn.layerStates) {
+ layersToMatch = len(conn.layerStates)
+ } else {
+ layersToMatch = len(override)
+ }
+ if len(received) < layersToMatch {
return false
}
- for i, s := range conn.layerStates {
- toMatch := s.incoming(received[i])
- if toMatch == nil {
- return false
- }
- if i < len(override) {
- toMatch.merge(override[i])
+ for i := 0; i < layersToMatch; i++ {
+ var toMatch Layer
+ if i < len(conn.layerStates) {
+ s := conn.layerStates[i]
+ toMatch = s.incoming(received[i])
+ if toMatch == nil {
+ return false
+ }
+ if i < len(override) {
+ if err := toMatch.merge(override[i]); err != nil {
+ conn.t.Fatalf("failed to merge: %s", err)
+ }
+ }
+ } else {
+ toMatch = override[i]
+ if toMatch == nil {
+ conn.t.Fatalf("expect the overriding layers to be non-nil")
+ }
}
if !toMatch.match(received[i]) {
return false
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
index 9335909c0..3f340c6bc 100644
--- a/test/packetimpact/testbench/dut.go
+++ b/test/packetimpact/testbench/dut.go
@@ -132,7 +132,7 @@ func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16)
copy(sa.Addr[:], addr.To16())
dut.Bind(fd, &sa)
} else {
- dut.t.Fatal("unknown ip addr type for remoteIP")
+ dut.t.Fatalf("unknown ip addr type for remoteIP")
}
sa := dut.GetSockName(fd)
var port int
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index 5ce324f0d..01e99567d 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -22,7 +22,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "github.com/imdario/mergo"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -111,13 +110,31 @@ func equalLayer(x, y Layer) bool {
return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{}))
}
-// mergeLayer merges other in layer. Any non-nil value in other overrides the
-// corresponding value in layer. If other is nil, no action is performed.
-func mergeLayer(layer, other Layer) error {
- if other == nil {
+// mergeLayer merges y into x. Any fields for which y has a non-nil value, that
+// value overwrite the corresponding fields in x.
+func mergeLayer(x, y Layer) error {
+ if y == nil {
return nil
}
- return mergo.Merge(layer, other, mergo.WithOverride)
+ if reflect.TypeOf(x) != reflect.TypeOf(y) {
+ return fmt.Errorf("can't merge %T into %T", y, x)
+ }
+ vx := reflect.ValueOf(x).Elem()
+ vy := reflect.ValueOf(y).Elem()
+ t := vy.Type()
+ for i := 0; i < vy.NumField(); i++ {
+ t := t.Field(i)
+ if t.Anonymous {
+ // Ignore the LayerBase in the Layer struct.
+ continue
+ }
+ v := vy.Field(i)
+ if v.IsNil() {
+ continue
+ }
+ vx.Field(i).Set(v)
+ }
+ return nil
}
func stringLayer(l Layer) string {
@@ -243,8 +260,7 @@ func (l *Ether) length() int {
return header.EthernetMinimumSize
}
-// merge overrides the values in l with the values from other but only in fields
-// where the value is not nil.
+// merge implements Layer.merge.
func (l *Ether) merge(other Layer) error {
return mergeLayer(l, other)
}
@@ -399,8 +415,7 @@ func (l *IPv4) length() int {
return int(*l.IHL)
}
-// merge overrides the values in l with the values from other but only in fields
-// where the value is not nil.
+// merge implements Layer.merge.
func (l *IPv4) merge(other Layer) error {
return mergeLayer(l, other)
}
@@ -544,8 +559,7 @@ func (l *TCP) length() int {
return int(*l.DataOffset)
}
-// merge overrides the values in l with the values from other but only in fields
-// where the value is not nil.
+// merge implements Layer.merge.
func (l *TCP) merge(other Layer) error {
return mergeLayer(l, other)
}
@@ -622,8 +636,7 @@ func (l *UDP) length() int {
return int(*l.Length)
}
-// merge overrides the values in l with the values from other but only in fields
-// where the value is not nil.
+// merge implements Layer.merge.
func (l *UDP) merge(other Layer) error {
return mergeLayer(l, other)
}
@@ -659,8 +672,7 @@ func (l *Payload) length() int {
return len(l.Bytes)
}
-// merge overrides the values in l with the values from other but only in fields
-// where the value is not nil.
+// merge implements Layer.merge.
func (l *Payload) merge(other Layer) error {
return mergeLayer(l, other)
}
diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go
index b32efda93..f07ec5eb2 100644
--- a/test/packetimpact/testbench/layers_test.go
+++ b/test/packetimpact/testbench/layers_test.go
@@ -17,6 +17,7 @@ package testbench
import (
"testing"
+ "github.com/mohae/deepcopy"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -52,6 +53,114 @@ func TestLayerMatch(t *testing.T) {
}
}
+func TestLayerMergeMismatch(t *testing.T) {
+ tcp := &TCP{}
+ otherTCP := &TCP{}
+ ipv4 := &IPv4{}
+ ether := &Ether{}
+ for _, tt := range []struct {
+ a, b Layer
+ success bool
+ }{
+ {tcp, tcp, true},
+ {tcp, otherTCP, true},
+ {tcp, ipv4, false},
+ {tcp, ether, false},
+ {tcp, nil, true},
+
+ {otherTCP, otherTCP, true},
+ {otherTCP, ipv4, false},
+ {otherTCP, ether, false},
+ {otherTCP, nil, true},
+
+ {ipv4, ipv4, true},
+ {ipv4, ether, false},
+ {ipv4, nil, true},
+
+ {ether, ether, true},
+ {ether, nil, true},
+ } {
+ if err := tt.a.merge(tt.b); (err == nil) != tt.success {
+ t.Errorf("%s.merge(%s) got %s, wanted the opposite", tt.a, tt.b, err)
+ }
+ if tt.b != nil {
+ if err := tt.b.merge(tt.a); (err == nil) != tt.success {
+ t.Errorf("%s.merge(%s) got %s, wanted the opposite", tt.b, tt.a, err)
+ }
+ }
+ }
+}
+
+func TestLayerMerge(t *testing.T) {
+ zero := Uint32(0)
+ one := Uint32(1)
+ two := Uint32(2)
+ empty := []byte{}
+ foo := []byte("foo")
+ bar := []byte("bar")
+ for _, tt := range []struct {
+ a, b Layer
+ want Layer
+ }{
+ {&TCP{AckNum: nil}, &TCP{AckNum: nil}, &TCP{AckNum: nil}},
+ {&TCP{AckNum: nil}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: nil}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: nil}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: nil}, nil, &TCP{AckNum: nil}},
+
+ {&TCP{AckNum: zero}, &TCP{AckNum: nil}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: zero}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: zero}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: zero}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: zero}, nil, &TCP{AckNum: zero}},
+
+ {&TCP{AckNum: one}, &TCP{AckNum: nil}, &TCP{AckNum: one}},
+ {&TCP{AckNum: one}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: one}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: one}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: one}, nil, &TCP{AckNum: one}},
+
+ {&TCP{AckNum: two}, &TCP{AckNum: nil}, &TCP{AckNum: two}},
+ {&TCP{AckNum: two}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: two}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: two}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: two}, nil, &TCP{AckNum: two}},
+
+ {&Payload{Bytes: nil}, &Payload{Bytes: nil}, &Payload{Bytes: nil}},
+ {&Payload{Bytes: nil}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: nil}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: nil}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: nil}, nil, &Payload{Bytes: nil}},
+
+ {&Payload{Bytes: empty}, &Payload{Bytes: nil}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: empty}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: empty}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: empty}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: empty}, nil, &Payload{Bytes: empty}},
+
+ {&Payload{Bytes: foo}, &Payload{Bytes: nil}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: foo}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: foo}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: foo}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: foo}, nil, &Payload{Bytes: foo}},
+
+ {&Payload{Bytes: bar}, &Payload{Bytes: nil}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: bar}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: bar}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: bar}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: bar}, nil, &Payload{Bytes: bar}},
+ } {
+ a := deepcopy.Copy(tt.a).(Layer)
+ if err := a.merge(tt.b); err != nil {
+ t.Errorf("%s.merge(%s) = %s, wanted nil", tt.a, tt.b, err)
+ continue
+ }
+ if a.String() != tt.want.String() {
+ t.Errorf("%s.merge(%s) merge result got %s, want %s", tt.a, tt.b, a, tt.want)
+ }
+ }
+}
+
func TestLayerStringFormat(t *testing.T) {
for _, tt := range []struct {
name string
@@ -154,3 +263,53 @@ func TestLayerStringFormat(t *testing.T) {
})
}
}
+
+func TestConnectionMatch(t *testing.T) {
+ conn := Connection{
+ layerStates: []layerState{&etherState{}},
+ }
+ protoNum0 := tcpip.NetworkProtocolNumber(0)
+ protoNum1 := tcpip.NetworkProtocolNumber(1)
+ for _, tt := range []struct {
+ description string
+ override, received Layers
+ wantMatch bool
+ }{
+ {
+ description: "shorter override",
+ override: []Layer{&Ether{}},
+ received: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}},
+ wantMatch: true,
+ },
+ {
+ description: "longer override",
+ override: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}},
+ received: []Layer{&Ether{}},
+ wantMatch: false,
+ },
+ {
+ description: "ether layer mismatch",
+ override: []Layer{&Ether{Type: &protoNum0}},
+ received: []Layer{&Ether{Type: &protoNum1}},
+ wantMatch: false,
+ },
+ {
+ description: "both nil",
+ override: nil,
+ received: nil,
+ wantMatch: false,
+ },
+ {
+ description: "nil override",
+ override: nil,
+ received: []Layer{&Ether{}},
+ wantMatch: true,
+ },
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ if gotMatch := conn.match(tt.override, tt.received); gotMatch != tt.wantMatch {
+ t.Fatalf("conn.match(%s, %s) = %t, want %t", tt.override, tt.received, gotMatch, tt.wantMatch)
+ }
+ })
+ }
+}