diff options
-rw-r--r-- | test/packetimpact/testbench/BUILD | 9 | ||||
-rw-r--r-- | test/packetimpact/testbench/layers.go | 61 | ||||
-rw-r--r-- | test/packetimpact/testbench/layers_test.go | 49 |
3 files changed, 110 insertions, 9 deletions
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD index 4a9d8efa6..199823419 100644 --- a/test/packetimpact/testbench/BUILD +++ b/test/packetimpact/testbench/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package( default_visibility = ["//test/packetimpact:__subpackages__"], @@ -30,3 +30,10 @@ go_library( "@org_golang_x_sys//unix:go_default_library", ], ) + +go_test( + name = "testbench_test", + size = "small", + srcs = ["layers_test.go"], + library = ":testbench", +) diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index d7434c3d2..4d6625941 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -17,6 +17,7 @@ package testbench import ( "fmt" "reflect" + "strings" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -32,6 +33,8 @@ import ( // Layer contains all the fields of the encapsulation. Each field is a pointer // and may be nil. type Layer interface { + fmt.Stringer + // toBytes converts the Layer into bytes. In places where the Layer's field // isn't nil, the value that is pointed to is used. When the field is nil, a // reasonable default for the Layer is used. For example, "64" for IPv4 TTL @@ -43,7 +46,8 @@ type Layer interface { // match checks if the current Layer matches the provided Layer. If either // Layer has a nil in a given field, that field is considered matching. - // Otherwise, the values pointed to by the fields must match. + // Otherwise, the values pointed to by the fields must match. The LayerBase is + // ignored. match(Layer) bool // length in bytes of the current encapsulation @@ -84,18 +88,39 @@ func (lb *LayerBase) setPrev(l Layer) { lb.prevLayer = l } +// equalLayer compares that two Layer structs match while ignoring field in +// which either input has a nil and also ignoring the LayerBase of the inputs. func equalLayer(x, y Layer) bool { + // opt ignores comparison pairs where either of the inputs is a nil. opt := cmp.FilterValues(func(x, y interface{}) bool { - if reflect.ValueOf(x).Kind() == reflect.Ptr && reflect.ValueOf(x).IsNil() { - return true - } - if reflect.ValueOf(y).Kind() == reflect.Ptr && reflect.ValueOf(y).IsNil() { - return true + for _, l := range []interface{}{x, y} { + v := reflect.ValueOf(l) + if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice) && v.IsNil() { + return true + } } return false - }, cmp.Ignore()) - return cmp.Equal(x, y, opt, cmpopts.IgnoreUnexported(LayerBase{})) + return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{})) +} + +func stringLayer(l Layer) string { + v := reflect.ValueOf(l).Elem() + t := v.Type() + var ret []string + for i := 0; i < v.NumField(); i++ { + t := t.Field(i) + if t.Anonymous { + // Ignore the LayerBase in the Layer struct. + continue + } + v := v.Field(i) + if v.IsNil() { + continue + } + ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v)) + } + return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " ")) } // Ether can construct and match an ethernet encapsulation. @@ -106,6 +131,10 @@ type Ether struct { Type *tcpip.NetworkProtocolNumber } +func (l *Ether) String() string { + return stringLayer(l) +} + func (l *Ether) toBytes() ([]byte, error) { b := make([]byte, header.EthernetMinimumSize) h := header.Ethernet(b) @@ -190,6 +219,10 @@ type IPv4 struct { DstAddr *tcpip.Address } +func (l *IPv4) String() string { + return stringLayer(l) +} + func (l *IPv4) toBytes() ([]byte, error) { b := make([]byte, header.IPv4MinimumSize) h := header.IPv4(b) @@ -339,6 +372,10 @@ type TCP struct { UrgentPointer *uint16 } +func (l *TCP) String() string { + return stringLayer(l) +} + func (l *TCP) toBytes() ([]byte, error) { b := make([]byte, header.TCPMinimumSize) h := header.TCP(b) @@ -480,6 +517,10 @@ type UDP struct { Checksum *uint16 } +func (l *UDP) String() string { + return stringLayer(l) +} + func (l *UDP) toBytes() ([]byte, error) { b := make([]byte, header.UDPMinimumSize) h := header.UDP(b) @@ -556,6 +597,10 @@ type Payload struct { Bytes []byte } +func (l *Payload) String() string { + return stringLayer(l) +} + // ParsePayload parses the bytes assuming that they start with a payload and // continue to the end. There can be no further encapsulations. func ParsePayload(b []byte) (Layers, error) { diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go new file mode 100644 index 000000000..b39839625 --- /dev/null +++ b/test/packetimpact/testbench/layers_test.go @@ -0,0 +1,49 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import "testing" + +func TestLayerMatch(t *testing.T) { + var nilPayload *Payload + noPayload := &Payload{} + emptyPayload := &Payload{Bytes: []byte{}} + fullPayload := &Payload{Bytes: []byte{1, 2, 3}} + emptyTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: emptyPayload}} + fullTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: fullPayload}} + for _, tt := range []struct { + a, b Layer + want bool + }{ + {nilPayload, nilPayload, true}, + {nilPayload, noPayload, true}, + {nilPayload, emptyPayload, true}, + {nilPayload, fullPayload, true}, + {noPayload, noPayload, true}, + {noPayload, emptyPayload, true}, + {noPayload, fullPayload, true}, + {emptyPayload, emptyPayload, true}, + {emptyPayload, fullPayload, false}, + {fullPayload, fullPayload, true}, + {emptyTCP, fullTCP, true}, + } { + if got := tt.a.match(tt.b); got != tt.want { + t.Errorf("%s.match(%s) = %t, want %t", tt.a, tt.b, got, tt.want) + } + if got := tt.b.match(tt.a); got != tt.want { + t.Errorf("%s.match(%s) = %t, want %t", tt.b, tt.a, got, tt.want) + } + } +} |