summaryrefslogtreecommitdiffhomepage
path: root/test/packetimpact/testbench/layers.go
diff options
context:
space:
mode:
Diffstat (limited to 'test/packetimpact/testbench/layers.go')
-rw-r--r--test/packetimpact/testbench/layers.go339
1 files changed, 271 insertions, 68 deletions
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index 35fa4dcb6..5ce324f0d 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -15,13 +15,16 @@
package testbench
import (
+ "encoding/hex"
"fmt"
"reflect"
+ "strings"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/imdario/mergo"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -31,6 +34,8 @@ import (
// Layer contains all the fields of the encapsulation. Each field is a pointer
// and may be nil.
type Layer interface {
+ fmt.Stringer
+
// toBytes converts the Layer into bytes. In places where the Layer's field
// isn't nil, the value that is pointed to is used. When the field is nil, a
// reasonable default for the Layer is used. For example, "64" for IPv4 TTL
@@ -42,7 +47,8 @@ type Layer interface {
// match checks if the current Layer matches the provided Layer. If either
// Layer has a nil in a given field, that field is considered matching.
- // Otherwise, the values pointed to by the fields must match.
+ // Otherwise, the values pointed to by the fields must match. The LayerBase is
+ // ignored.
match(Layer) bool
// length in bytes of the current encapsulation
@@ -59,6 +65,9 @@ type Layer interface {
// setPrev sets the pointer to the Layer encapsulating this one.
setPrev(Layer)
+
+ // merge overrides the values in the interface with the provided values.
+ merge(Layer) error
}
// LayerBase is the common elements of all layers.
@@ -83,21 +92,59 @@ func (lb *LayerBase) setPrev(l Layer) {
lb.prevLayer = l
}
+// equalLayer compares that two Layer structs match while ignoring field in
+// which either input has a nil and also ignoring the LayerBase of the inputs.
func equalLayer(x, y Layer) bool {
+ if x == nil || y == nil {
+ return true
+ }
+ // opt ignores comparison pairs where either of the inputs is a nil.
opt := cmp.FilterValues(func(x, y interface{}) bool {
- if reflect.ValueOf(x).Kind() == reflect.Ptr && reflect.ValueOf(x).IsNil() {
- return true
- }
- if reflect.ValueOf(y).Kind() == reflect.Ptr && reflect.ValueOf(y).IsNil() {
- return true
+ for _, l := range []interface{}{x, y} {
+ v := reflect.ValueOf(l)
+ if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice) && v.IsNil() {
+ return true
+ }
}
return false
-
}, cmp.Ignore())
- return cmp.Equal(x, y, opt, cmpopts.IgnoreUnexported(LayerBase{}))
+ return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{}))
+}
+
+// mergeLayer merges other in layer. Any non-nil value in other overrides the
+// corresponding value in layer. If other is nil, no action is performed.
+func mergeLayer(layer, other Layer) error {
+ if other == nil {
+ return nil
+ }
+ return mergo.Merge(layer, other, mergo.WithOverride)
+}
+
+func stringLayer(l Layer) string {
+ v := reflect.ValueOf(l).Elem()
+ t := v.Type()
+ var ret []string
+ for i := 0; i < v.NumField(); i++ {
+ t := t.Field(i)
+ if t.Anonymous {
+ // Ignore the LayerBase in the Layer struct.
+ continue
+ }
+ v := v.Field(i)
+ if v.IsNil() {
+ continue
+ }
+ v = reflect.Indirect(v)
+ if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
+ ret = append(ret, fmt.Sprintf("%s:\n%v", t.Name, hex.Dump(v.Bytes())))
+ } else {
+ ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v))
+ }
+ }
+ return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " "))
}
-// Ether can construct and match the ethernet encapsulation.
+// Ether can construct and match an ethernet encapsulation.
type Ether struct {
LayerBase
SrcAddr *tcpip.LinkAddress
@@ -105,6 +152,10 @@ type Ether struct {
Type *tcpip.NetworkProtocolNumber
}
+func (l *Ether) String() string {
+ return stringLayer(l)
+}
+
func (l *Ether) toBytes() ([]byte, error) {
b := make([]byte, header.EthernetMinimumSize)
h := header.Ethernet(b)
@@ -123,7 +174,7 @@ func (l *Ether) toBytes() ([]byte, error) {
fields.Type = header.IPv4ProtocolNumber
default:
// TODO(b/150301488): Support more protocols, like IPv6.
- return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", n)
+ return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n)
}
}
h.Encode(fields)
@@ -142,27 +193,46 @@ func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocol
return &v
}
-// ParseEther parses the bytes assuming that they start with an ethernet header
+// layerParser parses the input bytes and returns a Layer along with the next
+// layerParser to run. If there is no more parsing to do, the returned
+// layerParser is nil.
+type layerParser func([]byte) (Layer, layerParser)
+
+// parse parses bytes starting with the first layerParser and using successive
+// layerParsers until all the bytes are parsed.
+func parse(parser layerParser, b []byte) Layers {
+ var layers Layers
+ for {
+ var layer Layer
+ layer, parser = parser(b)
+ layers = append(layers, layer)
+ if parser == nil {
+ break
+ }
+ b = b[layer.length():]
+ }
+ layers.linkLayers()
+ return layers
+}
+
+// parseEther parses the bytes assuming that they start with an ethernet header
// and continues parsing further encapsulations.
-func ParseEther(b []byte) (Layers, error) {
+func parseEther(b []byte) (Layer, layerParser) {
h := header.Ethernet(b)
ether := Ether{
SrcAddr: LinkAddress(h.SourceAddress()),
DstAddr: LinkAddress(h.DestinationAddress()),
Type: NetworkProtocolNumber(h.Type()),
}
- layers := Layers{&ether}
+ var nextParser layerParser
switch h.Type() {
case header.IPv4ProtocolNumber:
- moreLayers, err := ParseIPv4(b[ether.length():])
- if err != nil {
- return nil, err
- }
- return append(layers, moreLayers...), nil
+ nextParser = parseIPv4
default:
- // TODO(b/150301488): Support more protocols, like IPv6.
- return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %v", b)
+ // Assume that the rest is a payload.
+ nextParser = parsePayload
}
+ return &ether, nextParser
}
func (l *Ether) match(other Layer) bool {
@@ -173,7 +243,13 @@ func (l *Ether) length() int {
return header.EthernetMinimumSize
}
-// IPv4 can construct and match the ethernet excapulation.
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *Ether) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// IPv4 can construct and match an IPv4 encapsulation.
type IPv4 struct {
LayerBase
IHL *uint8
@@ -189,6 +265,10 @@ type IPv4 struct {
DstAddr *tcpip.Address
}
+func (l *IPv4) String() string {
+ return stringLayer(l)
+}
+
func (l *IPv4) toBytes() ([]byte, error) {
b := make([]byte, header.IPv4MinimumSize)
h := header.IPv4(b)
@@ -236,9 +316,11 @@ func (l *IPv4) toBytes() ([]byte, error) {
switch n := l.next().(type) {
case *TCP:
fields.Protocol = uint8(header.TCPProtocolNumber)
+ case *UDP:
+ fields.Protocol = uint8(header.UDPProtocolNumber)
default:
- // TODO(b/150301488): Support more protocols, like UDP.
- return nil, fmt.Errorf("can't deduce the ip header's next protocol: %+v", n)
+ // TODO(b/150301488): Support more protocols as needed.
+ return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n)
}
}
if l.SrcAddr != nil {
@@ -275,9 +357,9 @@ func Address(v tcpip.Address) *tcpip.Address {
return &v
}
-// ParseIPv4 parses the bytes assuming that they start with an ipv4 header and
+// parseIPv4 parses the bytes assuming that they start with an ipv4 header and
// continues parsing further encapsulations.
-func ParseIPv4(b []byte) (Layers, error) {
+func parseIPv4(b []byte) (Layer, layerParser) {
h := header.IPv4(b)
tos, _ := h.TOS()
ipv4 := IPv4{
@@ -293,16 +375,17 @@ func ParseIPv4(b []byte) (Layers, error) {
SrcAddr: Address(h.SourceAddress()),
DstAddr: Address(h.DestinationAddress()),
}
- layers := Layers{&ipv4}
- switch h.Protocol() {
- case uint8(header.TCPProtocolNumber):
- moreLayers, err := ParseTCP(b[ipv4.length():])
- if err != nil {
- return nil, err
- }
- return append(layers, moreLayers...), nil
+ var nextParser layerParser
+ switch h.TransportProtocol() {
+ case header.TCPProtocolNumber:
+ nextParser = parseTCP
+ case header.UDPProtocolNumber:
+ nextParser = parseUDP
+ default:
+ // Assume that the rest is a payload.
+ nextParser = parsePayload
}
- return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", h.Protocol())
+ return &ipv4, nextParser
}
func (l *IPv4) match(other Layer) bool {
@@ -316,7 +399,13 @@ func (l *IPv4) length() int {
return int(*l.IHL)
}
-// TCP can construct and match the TCP excapulation.
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *IPv4) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// TCP can construct and match a TCP encapsulation.
type TCP struct {
LayerBase
SrcPort *uint16
@@ -330,6 +419,10 @@ type TCP struct {
UrgentPointer *uint16
}
+func (l *TCP) String() string {
+ return stringLayer(l)
+}
+
func (l *TCP) toBytes() ([]byte, error) {
b := make([]byte, header.TCPMinimumSize)
h := header.TCP(b)
@@ -347,12 +440,16 @@ func (l *TCP) toBytes() ([]byte, error) {
}
if l.DataOffset != nil {
h.SetDataOffset(*l.DataOffset)
+ } else {
+ h.SetDataOffset(uint8(l.length()))
}
if l.Flags != nil {
h.SetFlags(*l.Flags)
}
if l.WindowSize != nil {
h.SetWindowSize(*l.WindowSize)
+ } else {
+ h.SetWindowSize(32768)
}
if l.UrgentPointer != nil {
h.SetUrgentPoiner(*l.UrgentPointer)
@@ -361,38 +458,52 @@ func (l *TCP) toBytes() ([]byte, error) {
h.SetChecksum(*l.Checksum)
return h, nil
}
- if err := setChecksum(&h, l); err != nil {
+ if err := setTCPChecksum(&h, l); err != nil {
return nil, err
}
return h, nil
}
-// setChecksum calculates the checksum of the TCP header and sets it in h.
-func setChecksum(h *header.TCP, tcp *TCP) error {
- h.SetChecksum(0)
- tcpLength := uint16(tcp.length())
- current := tcp.next()
- for current != nil {
- tcpLength += uint16(current.length())
- current = current.next()
+// totalLength returns the length of the provided layer and all following
+// layers.
+func totalLength(l Layer) int {
+ var totalLength int
+ for ; l != nil; l = l.next() {
+ totalLength += l.length()
}
+ return totalLength
+}
+// layerChecksum calculates the checksum of the Layer header, including the
+// peusdeochecksum of the layer before it and all the bytes after it..
+func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) {
+ totalLength := uint16(totalLength(l))
var xsum uint16
- switch s := tcp.prev().(type) {
+ switch s := l.prev().(type) {
case *IPv4:
- xsum = header.PseudoHeaderChecksum(header.TCPProtocolNumber, *s.SrcAddr, *s.DstAddr, tcpLength)
+ xsum = header.PseudoHeaderChecksum(protoNumber, *s.SrcAddr, *s.DstAddr, totalLength)
default:
// TODO(b/150301488): Support more protocols, like IPv6.
- return fmt.Errorf("can't get src and dst addr from previous layer")
+ return 0, fmt.Errorf("can't get src and dst addr from previous layer: %#v", s)
}
- current = tcp.next()
- for current != nil {
+ var payloadBytes buffer.VectorisedView
+ for current := l.next(); current != nil; current = current.next() {
payload, err := current.toBytes()
if err != nil {
- return fmt.Errorf("can't get bytes for next header: %s", payload)
+ return 0, fmt.Errorf("can't get bytes for next header: %s", payload)
}
- xsum = header.Checksum(payload, xsum)
- current = current.next()
+ payloadBytes.AppendView(payload)
+ }
+ xsum = header.ChecksumVV(payloadBytes, xsum)
+ return xsum, nil
+}
+
+// setTCPChecksum calculates the checksum of the TCP header and sets it in h.
+func setTCPChecksum(h *header.TCP, tcp *TCP) error {
+ h.SetChecksum(0)
+ xsum, err := layerChecksum(tcp, header.TCPProtocolNumber)
+ if err != nil {
+ return err
}
h.SetChecksum(^h.CalculateChecksum(xsum))
return nil
@@ -404,9 +515,9 @@ func Uint32(v uint32) *uint32 {
return &v
}
-// ParseTCP parses the bytes assuming that they start with a tcp header and
+// parseTCP parses the bytes assuming that they start with a tcp header and
// continues parsing further encapsulations.
-func ParseTCP(b []byte) (Layers, error) {
+func parseTCP(b []byte) (Layer, layerParser) {
h := header.TCP(b)
tcp := TCP{
SrcPort: Uint16(h.SourcePort()),
@@ -419,12 +530,7 @@ func ParseTCP(b []byte) (Layers, error) {
Checksum: Uint16(h.Checksum()),
UrgentPointer: Uint16(h.UrgentPointer()),
}
- layers := Layers{&tcp}
- moreLayers, err := ParsePayload(b[tcp.length():])
- if err != nil {
- return nil, err
- }
- return append(layers, moreLayers...), nil
+ return &tcp, parsePayload
}
func (l *TCP) match(other Layer) bool {
@@ -440,8 +546,86 @@ func (l *TCP) length() int {
// merge overrides the values in l with the values from other but only in fields
// where the value is not nil.
-func (l *TCP) merge(other TCP) error {
- return mergo.Merge(l, other, mergo.WithOverride)
+func (l *TCP) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// UDP can construct and match a UDP encapsulation.
+type UDP struct {
+ LayerBase
+ SrcPort *uint16
+ DstPort *uint16
+ Length *uint16
+ Checksum *uint16
+}
+
+func (l *UDP) String() string {
+ return stringLayer(l)
+}
+
+func (l *UDP) toBytes() ([]byte, error) {
+ b := make([]byte, header.UDPMinimumSize)
+ h := header.UDP(b)
+ if l.SrcPort != nil {
+ h.SetSourcePort(*l.SrcPort)
+ }
+ if l.DstPort != nil {
+ h.SetDestinationPort(*l.DstPort)
+ }
+ if l.Length != nil {
+ h.SetLength(*l.Length)
+ } else {
+ h.SetLength(uint16(totalLength(l)))
+ }
+ if l.Checksum != nil {
+ h.SetChecksum(*l.Checksum)
+ return h, nil
+ }
+ if err := setUDPChecksum(&h, l); err != nil {
+ return nil, err
+ }
+ return h, nil
+}
+
+// setUDPChecksum calculates the checksum of the UDP header and sets it in h.
+func setUDPChecksum(h *header.UDP, udp *UDP) error {
+ h.SetChecksum(0)
+ xsum, err := layerChecksum(udp, header.UDPProtocolNumber)
+ if err != nil {
+ return err
+ }
+ h.SetChecksum(^h.CalculateChecksum(xsum))
+ return nil
+}
+
+// parseUDP parses the bytes assuming that they start with a udp header and
+// returns the parsed layer and the next parser to use.
+func parseUDP(b []byte) (Layer, layerParser) {
+ h := header.UDP(b)
+ udp := UDP{
+ SrcPort: Uint16(h.SourcePort()),
+ DstPort: Uint16(h.DestinationPort()),
+ Length: Uint16(h.Length()),
+ Checksum: Uint16(h.Checksum()),
+ }
+ return &udp, parsePayload
+}
+
+func (l *UDP) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *UDP) length() int {
+ if l.Length == nil {
+ return header.UDPMinimumSize
+ }
+ return int(*l.Length)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *UDP) merge(other Layer) error {
+ return mergeLayer(l, other)
}
// Payload has bytes beyond OSI layer 4.
@@ -450,13 +634,17 @@ type Payload struct {
Bytes []byte
}
-// ParsePayload parses the bytes assuming that they start with a payload and
+func (l *Payload) String() string {
+ return stringLayer(l)
+}
+
+// parsePayload parses the bytes assuming that they start with a payload and
// continue to the end. There can be no further encapsulations.
-func ParsePayload(b []byte) (Layers, error) {
+func parsePayload(b []byte) (Layer, layerParser) {
payload := Payload{
Bytes: b,
}
- return Layers{&payload}, nil
+ return &payload, nil
}
func (l *Payload) toBytes() ([]byte, error) {
@@ -471,18 +659,33 @@ func (l *Payload) length() int {
return len(l.Bytes)
}
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *Payload) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
// Layers is an array of Layer and supports similar functions to Layer.
type Layers []Layer
-func (ls *Layers) toBytes() ([]byte, error) {
+// linkLayers sets the linked-list ponters in ls.
+func (ls *Layers) linkLayers() {
for i, l := range *ls {
if i > 0 {
l.setPrev((*ls)[i-1])
+ } else {
+ l.setPrev(nil)
}
if i+1 < len(*ls) {
l.setNext((*ls)[i+1])
+ } else {
+ l.setNext(nil)
}
}
+}
+
+func (ls *Layers) toBytes() ([]byte, error) {
+ ls.linkLayers()
outBytes := []byte{}
for _, l := range *ls {
layerBytes, err := l.toBytes()
@@ -498,8 +701,8 @@ func (ls *Layers) match(other Layers) bool {
if len(*ls) > len(other) {
return false
}
- for i := 0; i < len(*ls); i++ {
- if !equalLayer((*ls)[i], other[i]) {
+ for i, l := range *ls {
+ if !equalLayer(l, other[i]) {
return false
}
}