diff options
Diffstat (limited to 'test/packetimpact/testbench/layers.go')
-rw-r--r-- | test/packetimpact/testbench/layers.go | 160 |
1 files changed, 106 insertions, 54 deletions
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index 4d6625941..5ce324f0d 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -15,6 +15,7 @@ package testbench import ( + "encoding/hex" "fmt" "reflect" "strings" @@ -64,6 +65,9 @@ type Layer interface { // setPrev sets the pointer to the Layer encapsulating this one. setPrev(Layer) + + // merge overrides the values in the interface with the provided values. + merge(Layer) error } // LayerBase is the common elements of all layers. @@ -91,6 +95,9 @@ func (lb *LayerBase) setPrev(l Layer) { // equalLayer compares that two Layer structs match while ignoring field in // which either input has a nil and also ignoring the LayerBase of the inputs. func equalLayer(x, y Layer) bool { + if x == nil || y == nil { + return true + } // opt ignores comparison pairs where either of the inputs is a nil. opt := cmp.FilterValues(func(x, y interface{}) bool { for _, l := range []interface{}{x, y} { @@ -104,6 +111,15 @@ func equalLayer(x, y Layer) bool { return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{})) } +// mergeLayer merges other in layer. Any non-nil value in other overrides the +// corresponding value in layer. If other is nil, no action is performed. +func mergeLayer(layer, other Layer) error { + if other == nil { + return nil + } + return mergo.Merge(layer, other, mergo.WithOverride) +} + func stringLayer(l Layer) string { v := reflect.ValueOf(l).Elem() t := v.Type() @@ -118,7 +134,12 @@ func stringLayer(l Layer) string { if v.IsNil() { continue } - ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v)) + v = reflect.Indirect(v) + if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 { + ret = append(ret, fmt.Sprintf("%s:\n%v", t.Name, hex.Dump(v.Bytes()))) + } else { + ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v)) + } } return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " ")) } @@ -153,7 +174,7 @@ func (l *Ether) toBytes() ([]byte, error) { fields.Type = header.IPv4ProtocolNumber default: // TODO(b/150301488): Support more protocols, like IPv6. - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", n) + return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n) } } h.Encode(fields) @@ -172,27 +193,46 @@ func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocol return &v } -// ParseEther parses the bytes assuming that they start with an ethernet header +// layerParser parses the input bytes and returns a Layer along with the next +// layerParser to run. If there is no more parsing to do, the returned +// layerParser is nil. +type layerParser func([]byte) (Layer, layerParser) + +// parse parses bytes starting with the first layerParser and using successive +// layerParsers until all the bytes are parsed. +func parse(parser layerParser, b []byte) Layers { + var layers Layers + for { + var layer Layer + layer, parser = parser(b) + layers = append(layers, layer) + if parser == nil { + break + } + b = b[layer.length():] + } + layers.linkLayers() + return layers +} + +// parseEther parses the bytes assuming that they start with an ethernet header // and continues parsing further encapsulations. -func ParseEther(b []byte) (Layers, error) { +func parseEther(b []byte) (Layer, layerParser) { h := header.Ethernet(b) ether := Ether{ SrcAddr: LinkAddress(h.SourceAddress()), DstAddr: LinkAddress(h.DestinationAddress()), Type: NetworkProtocolNumber(h.Type()), } - layers := Layers{ðer} + var nextParser layerParser switch h.Type() { case header.IPv4ProtocolNumber: - moreLayers, err := ParseIPv4(b[ether.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + nextParser = parseIPv4 default: - // TODO(b/150301488): Support more protocols, like IPv6. - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %#v", b) + // Assume that the rest is a payload. + nextParser = parsePayload } + return ðer, nextParser } func (l *Ether) match(other Layer) bool { @@ -203,6 +243,12 @@ func (l *Ether) length() int { return header.EthernetMinimumSize } +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *Ether) merge(other Layer) error { + return mergeLayer(l, other) +} + // IPv4 can construct and match an IPv4 encapsulation. type IPv4 struct { LayerBase @@ -274,7 +320,7 @@ func (l *IPv4) toBytes() ([]byte, error) { fields.Protocol = uint8(header.UDPProtocolNumber) default: // TODO(b/150301488): Support more protocols as needed. - return nil, fmt.Errorf("can't deduce the ip header's next protocol: %#v", n) + return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n) } } if l.SrcAddr != nil { @@ -311,9 +357,9 @@ func Address(v tcpip.Address) *tcpip.Address { return &v } -// ParseIPv4 parses the bytes assuming that they start with an ipv4 header and +// parseIPv4 parses the bytes assuming that they start with an ipv4 header and // continues parsing further encapsulations. -func ParseIPv4(b []byte) (Layers, error) { +func parseIPv4(b []byte) (Layer, layerParser) { h := header.IPv4(b) tos, _ := h.TOS() ipv4 := IPv4{ @@ -329,22 +375,17 @@ func ParseIPv4(b []byte) (Layers, error) { SrcAddr: Address(h.SourceAddress()), DstAddr: Address(h.DestinationAddress()), } - layers := Layers{&ipv4} + var nextParser layerParser switch h.TransportProtocol() { case header.TCPProtocolNumber: - moreLayers, err := ParseTCP(b[ipv4.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + nextParser = parseTCP case header.UDPProtocolNumber: - moreLayers, err := ParseUDP(b[ipv4.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + nextParser = parseUDP + default: + // Assume that the rest is a payload. + nextParser = parsePayload } - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", h.Protocol()) + return &ipv4, nextParser } func (l *IPv4) match(other Layer) bool { @@ -358,6 +399,12 @@ func (l *IPv4) length() int { return int(*l.IHL) } +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv4) merge(other Layer) error { + return mergeLayer(l, other) +} + // TCP can construct and match a TCP encapsulation. type TCP struct { LayerBase @@ -468,9 +515,9 @@ func Uint32(v uint32) *uint32 { return &v } -// ParseTCP parses the bytes assuming that they start with a tcp header and +// parseTCP parses the bytes assuming that they start with a tcp header and // continues parsing further encapsulations. -func ParseTCP(b []byte) (Layers, error) { +func parseTCP(b []byte) (Layer, layerParser) { h := header.TCP(b) tcp := TCP{ SrcPort: Uint16(h.SourcePort()), @@ -483,12 +530,7 @@ func ParseTCP(b []byte) (Layers, error) { Checksum: Uint16(h.Checksum()), UrgentPointer: Uint16(h.UrgentPointer()), } - layers := Layers{&tcp} - moreLayers, err := ParsePayload(b[tcp.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + return &tcp, parsePayload } func (l *TCP) match(other Layer) bool { @@ -504,8 +546,8 @@ func (l *TCP) length() int { // merge overrides the values in l with the values from other but only in fields // where the value is not nil. -func (l *TCP) merge(other TCP) error { - return mergo.Merge(l, other, mergo.WithOverride) +func (l *TCP) merge(other Layer) error { + return mergeLayer(l, other) } // UDP can construct and match a UDP encapsulation. @@ -556,9 +598,9 @@ func setUDPChecksum(h *header.UDP, udp *UDP) error { return nil } -// ParseUDP parses the bytes assuming that they start with a udp header and -// continues parsing further encapsulations. -func ParseUDP(b []byte) (Layers, error) { +// parseUDP parses the bytes assuming that they start with a udp header and +// returns the parsed layer and the next parser to use. +func parseUDP(b []byte) (Layer, layerParser) { h := header.UDP(b) udp := UDP{ SrcPort: Uint16(h.SourcePort()), @@ -566,12 +608,7 @@ func ParseUDP(b []byte) (Layers, error) { Length: Uint16(h.Length()), Checksum: Uint16(h.Checksum()), } - layers := Layers{&udp} - moreLayers, err := ParsePayload(b[udp.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + return &udp, parsePayload } func (l *UDP) match(other Layer) bool { @@ -587,8 +624,8 @@ func (l *UDP) length() int { // merge overrides the values in l with the values from other but only in fields // where the value is not nil. -func (l *UDP) merge(other UDP) error { - return mergo.Merge(l, other, mergo.WithOverride) +func (l *UDP) merge(other Layer) error { + return mergeLayer(l, other) } // Payload has bytes beyond OSI layer 4. @@ -601,13 +638,13 @@ func (l *Payload) String() string { return stringLayer(l) } -// ParsePayload parses the bytes assuming that they start with a payload and +// parsePayload parses the bytes assuming that they start with a payload and // continue to the end. There can be no further encapsulations. -func ParsePayload(b []byte) (Layers, error) { +func parsePayload(b []byte) (Layer, layerParser) { payload := Payload{ Bytes: b, } - return Layers{&payload}, nil + return &payload, nil } func (l *Payload) toBytes() ([]byte, error) { @@ -622,18 +659,33 @@ func (l *Payload) length() int { return len(l.Bytes) } +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *Payload) merge(other Layer) error { + return mergeLayer(l, other) +} + // Layers is an array of Layer and supports similar functions to Layer. type Layers []Layer -func (ls *Layers) toBytes() ([]byte, error) { +// linkLayers sets the linked-list ponters in ls. +func (ls *Layers) linkLayers() { for i, l := range *ls { if i > 0 { l.setPrev((*ls)[i-1]) + } else { + l.setPrev(nil) } if i+1 < len(*ls) { l.setNext((*ls)[i+1]) + } else { + l.setNext(nil) } } +} + +func (ls *Layers) toBytes() ([]byte, error) { + ls.linkLayers() outBytes := []byte{} for _, l := range *ls { layerBytes, err := l.toBytes() @@ -649,8 +701,8 @@ func (ls *Layers) match(other Layers) bool { if len(*ls) > len(other) { return false } - for i := 0; i < len(*ls); i++ { - if !equalLayer((*ls)[i], other[i]) { + for i, l := range *ls { + if !equalLayer(l, other[i]) { return false } } |