summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorEyal Soha <eyalsoha@google.com>2020-04-23 18:20:43 -0700
committergVisor bot <gvisor-bot@google.com>2020-04-23 18:24:31 -0700
commit79542417fe97a62ee86aa211ac559bcc5cac5e5e (patch)
tree9d60f6f6de50d368818034b7ddfaabad483f8ab6
parentf01f2132d8d3e551579cba9a1b942b4b70d83f21 (diff)
Fix Layer merge and add unit tests
mergo was improperly merging nil and empty strings PiperOrigin-RevId: 308170862
-rw-r--r--WORKSPACE7
-rw-r--r--test/packetimpact/testbench/BUILD6
-rw-r--r--test/packetimpact/testbench/layers.go44
-rw-r--r--test/packetimpact/testbench/layers_test.go109
4 files changed, 141 insertions, 25 deletions
diff --git a/WORKSPACE b/WORKSPACE
index 3bf5cc9c1..c86e0fcdc 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -189,13 +189,6 @@ go_repository(
)
go_repository(
- name = "com_github_imdario_mergo",
- importpath = "github.com/imdario/mergo",
- sum = "h1:CGgOkSJeqMRmt0D9XLWExdT4m4F1vd3FV3VPt+0VxkQ=",
- version = "v0.3.8",
-)
-
-go_repository(
name = "com_github_kr_pretty",
importpath = "github.com/kr/pretty",
sum = "h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=",
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
index b6a254882..3ceceb9d7 100644
--- a/test/packetimpact/testbench/BUILD
+++ b/test/packetimpact/testbench/BUILD
@@ -23,7 +23,6 @@ go_library(
"//test/packetimpact/proto:posix_server_go_proto",
"@com_github_google_go-cmp//cmp:go_default_library",
"@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
- "@com_github_imdario_mergo//:go_default_library",
"@com_github_mohae_deepcopy//:go_default_library",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//keepalive:go_default_library",
@@ -37,5 +36,8 @@ go_test(
size = "small",
srcs = ["layers_test.go"],
library = ":testbench",
- deps = ["//pkg/tcpip"],
+ deps = [
+ "//pkg/tcpip",
+ "@com_github_mohae_deepcopy//:go_default_library",
+ ],
)
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index 5ce324f0d..01e99567d 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -22,7 +22,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "github.com/imdario/mergo"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -111,13 +110,31 @@ func equalLayer(x, y Layer) bool {
return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{}))
}
-// mergeLayer merges other in layer. Any non-nil value in other overrides the
-// corresponding value in layer. If other is nil, no action is performed.
-func mergeLayer(layer, other Layer) error {
- if other == nil {
+// mergeLayer merges y into x. Any fields for which y has a non-nil value, that
+// value overwrite the corresponding fields in x.
+func mergeLayer(x, y Layer) error {
+ if y == nil {
return nil
}
- return mergo.Merge(layer, other, mergo.WithOverride)
+ if reflect.TypeOf(x) != reflect.TypeOf(y) {
+ return fmt.Errorf("can't merge %T into %T", y, x)
+ }
+ vx := reflect.ValueOf(x).Elem()
+ vy := reflect.ValueOf(y).Elem()
+ t := vy.Type()
+ for i := 0; i < vy.NumField(); i++ {
+ t := t.Field(i)
+ if t.Anonymous {
+ // Ignore the LayerBase in the Layer struct.
+ continue
+ }
+ v := vy.Field(i)
+ if v.IsNil() {
+ continue
+ }
+ vx.Field(i).Set(v)
+ }
+ return nil
}
func stringLayer(l Layer) string {
@@ -243,8 +260,7 @@ func (l *Ether) length() int {
return header.EthernetMinimumSize
}
-// merge overrides the values in l with the values from other but only in fields
-// where the value is not nil.
+// merge implements Layer.merge.
func (l *Ether) merge(other Layer) error {
return mergeLayer(l, other)
}
@@ -399,8 +415,7 @@ func (l *IPv4) length() int {
return int(*l.IHL)
}
-// merge overrides the values in l with the values from other but only in fields
-// where the value is not nil.
+// merge implements Layer.merge.
func (l *IPv4) merge(other Layer) error {
return mergeLayer(l, other)
}
@@ -544,8 +559,7 @@ func (l *TCP) length() int {
return int(*l.DataOffset)
}
-// merge overrides the values in l with the values from other but only in fields
-// where the value is not nil.
+// merge implements Layer.merge.
func (l *TCP) merge(other Layer) error {
return mergeLayer(l, other)
}
@@ -622,8 +636,7 @@ func (l *UDP) length() int {
return int(*l.Length)
}
-// merge overrides the values in l with the values from other but only in fields
-// where the value is not nil.
+// merge implements Layer.merge.
func (l *UDP) merge(other Layer) error {
return mergeLayer(l, other)
}
@@ -659,8 +672,7 @@ func (l *Payload) length() int {
return len(l.Bytes)
}
-// merge overrides the values in l with the values from other but only in fields
-// where the value is not nil.
+// merge implements Layer.merge.
func (l *Payload) merge(other Layer) error {
return mergeLayer(l, other)
}
diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go
index c99cf6312..f07ec5eb2 100644
--- a/test/packetimpact/testbench/layers_test.go
+++ b/test/packetimpact/testbench/layers_test.go
@@ -17,6 +17,7 @@ package testbench
import (
"testing"
+ "github.com/mohae/deepcopy"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -52,6 +53,114 @@ func TestLayerMatch(t *testing.T) {
}
}
+func TestLayerMergeMismatch(t *testing.T) {
+ tcp := &TCP{}
+ otherTCP := &TCP{}
+ ipv4 := &IPv4{}
+ ether := &Ether{}
+ for _, tt := range []struct {
+ a, b Layer
+ success bool
+ }{
+ {tcp, tcp, true},
+ {tcp, otherTCP, true},
+ {tcp, ipv4, false},
+ {tcp, ether, false},
+ {tcp, nil, true},
+
+ {otherTCP, otherTCP, true},
+ {otherTCP, ipv4, false},
+ {otherTCP, ether, false},
+ {otherTCP, nil, true},
+
+ {ipv4, ipv4, true},
+ {ipv4, ether, false},
+ {ipv4, nil, true},
+
+ {ether, ether, true},
+ {ether, nil, true},
+ } {
+ if err := tt.a.merge(tt.b); (err == nil) != tt.success {
+ t.Errorf("%s.merge(%s) got %s, wanted the opposite", tt.a, tt.b, err)
+ }
+ if tt.b != nil {
+ if err := tt.b.merge(tt.a); (err == nil) != tt.success {
+ t.Errorf("%s.merge(%s) got %s, wanted the opposite", tt.b, tt.a, err)
+ }
+ }
+ }
+}
+
+func TestLayerMerge(t *testing.T) {
+ zero := Uint32(0)
+ one := Uint32(1)
+ two := Uint32(2)
+ empty := []byte{}
+ foo := []byte("foo")
+ bar := []byte("bar")
+ for _, tt := range []struct {
+ a, b Layer
+ want Layer
+ }{
+ {&TCP{AckNum: nil}, &TCP{AckNum: nil}, &TCP{AckNum: nil}},
+ {&TCP{AckNum: nil}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: nil}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: nil}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: nil}, nil, &TCP{AckNum: nil}},
+
+ {&TCP{AckNum: zero}, &TCP{AckNum: nil}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: zero}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: zero}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: zero}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: zero}, nil, &TCP{AckNum: zero}},
+
+ {&TCP{AckNum: one}, &TCP{AckNum: nil}, &TCP{AckNum: one}},
+ {&TCP{AckNum: one}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: one}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: one}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: one}, nil, &TCP{AckNum: one}},
+
+ {&TCP{AckNum: two}, &TCP{AckNum: nil}, &TCP{AckNum: two}},
+ {&TCP{AckNum: two}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: two}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: two}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: two}, nil, &TCP{AckNum: two}},
+
+ {&Payload{Bytes: nil}, &Payload{Bytes: nil}, &Payload{Bytes: nil}},
+ {&Payload{Bytes: nil}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: nil}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: nil}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: nil}, nil, &Payload{Bytes: nil}},
+
+ {&Payload{Bytes: empty}, &Payload{Bytes: nil}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: empty}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: empty}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: empty}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: empty}, nil, &Payload{Bytes: empty}},
+
+ {&Payload{Bytes: foo}, &Payload{Bytes: nil}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: foo}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: foo}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: foo}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: foo}, nil, &Payload{Bytes: foo}},
+
+ {&Payload{Bytes: bar}, &Payload{Bytes: nil}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: bar}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: bar}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: bar}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: bar}, nil, &Payload{Bytes: bar}},
+ } {
+ a := deepcopy.Copy(tt.a).(Layer)
+ if err := a.merge(tt.b); err != nil {
+ t.Errorf("%s.merge(%s) = %s, wanted nil", tt.a, tt.b, err)
+ continue
+ }
+ if a.String() != tt.want.String() {
+ t.Errorf("%s.merge(%s) merge result got %s, want %s", tt.a, tt.b, a, tt.want)
+ }
+ }
+}
+
func TestLayerStringFormat(t *testing.T) {
for _, tt := range []struct {
name string