summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/flipcall/BUILD1
-rw-r--r--pkg/flipcall/flipcall.go2
-rw-r--r--pkg/flipcall/packet_window_mmap.go25
-rw-r--r--pkg/merkletree/BUILD16
-rw-r--r--pkg/merkletree/merkletree.go135
-rw-r--r--pkg/merkletree/merkletree_test.go122
-rw-r--r--test/packetimpact/testbench/connections.go211
-rw-r--r--test/packetimpact/tests/icmpv6_param_problem_test.go4
-rw-r--r--test/packetimpact/tests/udp_icmp_error_propagation_test.go25
-rw-r--r--test/packetimpact/tests/udp_recv_multicast_test.go4
-rw-r--r--test/packetimpact/tests/udp_send_recv_dgram_test.go3
11 files changed, 437 insertions, 111 deletions
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index 9c5ad500b..aa8e4e1f3 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -11,6 +11,7 @@ go_library(
"futex_linux.go",
"io.go",
"packet_window_allocator.go",
+ "packet_window_mmap.go",
],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go
index 3cdb576e1..ec742c091 100644
--- a/pkg/flipcall/flipcall.go
+++ b/pkg/flipcall/flipcall.go
@@ -95,7 +95,7 @@ func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ...
if pwd.Length > math.MaxUint32 {
return fmt.Errorf("packet window size (%d) exceeds maximum (%d)", pwd.Length, math.MaxUint32)
}
- m, _, e := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
+ m, e := packetWindowMmap(pwd)
if e != 0 {
return fmt.Errorf("failed to mmap packet window: %v", e)
}
diff --git a/pkg/flipcall/packet_window_mmap.go b/pkg/flipcall/packet_window_mmap.go
new file mode 100644
index 000000000..869183b11
--- /dev/null
+++ b/pkg/flipcall/packet_window_mmap.go
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flipcall
+
+import (
+ "syscall"
+)
+
+// Return a memory mapping of the pwd in memory that can be shared outside the sandbox.
+func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, syscall.Errno) {
+ m, _, err := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
+ return m, err
+}
diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD
new file mode 100644
index 000000000..5b0e4143a
--- /dev/null
+++ b/pkg/merkletree/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "merkletree",
+ srcs = ["merkletree.go"],
+ deps = ["//pkg/usermem"],
+)
+
+go_test(
+ name = "merkletree_test",
+ srcs = ["merkletree_test.go"],
+ library = ":merkletree",
+ deps = ["//pkg/usermem"],
+)
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
new file mode 100644
index 000000000..906f67943
--- /dev/null
+++ b/pkg/merkletree/merkletree.go
@@ -0,0 +1,135 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package merkletree implements Merkle tree generating and verification.
+package merkletree
+
+import (
+ "crypto/sha256"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ // sha256DigestSize specifies the digest size of a SHA256 hash.
+ sha256DigestSize = 32
+)
+
+// Size defines the scale of a Merkle tree.
+type Size struct {
+ // blockSize is the size of a data block to be hashed.
+ blockSize int64
+ // digestSize is the size of a generated hash.
+ digestSize int64
+ // hashesPerBlock is the number of hashes in a block. For example, if
+ // blockSize is 4096 bytes, and digestSize is 32 bytes, there will be 128
+ // hashesPerBlock. Therefore 128 hashes in a lower level will be put into a
+ // block and generate a single hash in an upper level.
+ hashesPerBlock int64
+ // levelStart is the start block index of each level. The number of levels in
+ // the tree is the length of the slice. The leafs (level 0) are hashes of
+ // blocks in the input data. The levels above are hashes of lower level
+ // hashes. The highest level is the root hash.
+ levelStart []int64
+}
+
+// MakeSize initializes and returns a new Size object describing the structure
+// of a tree. dataSize specifies the number of the file system size in bytes.
+func MakeSize(dataSize int64) Size {
+ size := Size{
+ blockSize: usermem.PageSize,
+ // TODO(b/156980949): Allow config other hash methods (SHA384/SHA512).
+ digestSize: sha256DigestSize,
+ hashesPerBlock: usermem.PageSize / sha256DigestSize,
+ }
+ numBlocks := (dataSize + size.blockSize - 1) / size.blockSize
+ level := int64(0)
+ offset := int64(0)
+
+ // Calcuate the number of levels in the Merkle tree and the beginning offset
+ // of each level. Level 0 is the level directly above the data blocks, while
+ // level NumLevels - 1 is the root.
+ for numBlocks > 1 {
+ size.levelStart = append(size.levelStart, offset)
+ // Round numBlocks up to fill up a block.
+ numBlocks += (size.hashesPerBlock - numBlocks%size.hashesPerBlock) % size.hashesPerBlock
+ offset += numBlocks / size.hashesPerBlock
+ numBlocks = numBlocks / size.hashesPerBlock
+ level++
+ }
+ size.levelStart = append(size.levelStart, offset)
+ return size
+}
+
+// Generate constructs a Merkle tree for the contents of data. The output is
+// written to treeWriter. The treeReader should be able to read the tree after
+// it has been written. That is, treeWriter and treeReader should point to the
+// same underlying data but have separate cursors.
+func Generate(data io.Reader, dataSize int64, treeReader io.Reader, treeWriter io.Writer) ([]byte, error) {
+ size := MakeSize(dataSize)
+
+ numBlocks := (dataSize + size.blockSize - 1) / size.blockSize
+
+ var root []byte
+ for level := 0; level < len(size.levelStart); level++ {
+ for i := int64(0); i < numBlocks; i++ {
+ buf := make([]byte, size.blockSize)
+ var (
+ n int
+ err error
+ )
+ if level == 0 {
+ // Read data block from the target file since level 0 is directly above
+ // the raw data block.
+ n, err = data.Read(buf)
+ } else {
+ // Read data block from the tree file since levels higher than 0 are
+ // hashing the lower level hashes.
+ n, err = treeReader.Read(buf)
+ }
+
+ // err is populated as long as the bytes read is smaller than the buffer
+ // size. This could be the case if we are reading the last block, and
+ // break in that case. If this is the last block, the end of the block
+ // will be zero-padded.
+ if n == 0 && err == io.EOF {
+ break
+ } else if err != nil && err != io.EOF {
+ return nil, err
+ }
+ // Hash the bytes in buf.
+ digest := sha256.Sum256(buf)
+
+ if level == len(size.levelStart)-1 {
+ root = digest[:]
+ }
+
+ // Write the generated hash to the end of the tree file.
+ if _, err = treeWriter.Write(digest[:]); err != nil {
+ return nil, err
+ }
+ }
+ // If the genereated digests do not round up to a block, zero-padding the
+ // remaining of the last block. But no need to do so for root.
+ if level != len(size.levelStart)-1 && numBlocks%size.hashesPerBlock != 0 {
+ zeroBuf := make([]byte, size.blockSize-(numBlocks%size.hashesPerBlock)*size.digestSize)
+ if _, err := treeWriter.Write(zeroBuf[:]); err != nil {
+ return nil, err
+ }
+ }
+ numBlocks = (numBlocks + size.hashesPerBlock - 1) / size.hashesPerBlock
+ }
+ return root, nil
+}
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
new file mode 100644
index 000000000..7344db0b6
--- /dev/null
+++ b/pkg/merkletree/merkletree_test.go
@@ -0,0 +1,122 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package merkletree
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func TestSize(t *testing.T) {
+ testCases := []struct {
+ dataSize int64
+ expectedLevelStart []int64
+ }{
+ {
+ dataSize: 100,
+ expectedLevelStart: []int64{0},
+ },
+ {
+ dataSize: 1000000,
+ expectedLevelStart: []int64{0, 2, 3},
+ },
+ {
+ dataSize: 4096 * int64(usermem.PageSize),
+ expectedLevelStart: []int64{0, 32, 33},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) {
+ s := MakeSize(tc.dataSize)
+ if s.blockSize != int64(usermem.PageSize) {
+ t.Errorf("got blockSize %d, want %d", s.blockSize, usermem.PageSize)
+ }
+ if s.digestSize != sha256DigestSize {
+ t.Errorf("got digestSize %d, want %d", s.digestSize, sha256DigestSize)
+ }
+ if len(s.levelStart) != len(tc.expectedLevelStart) {
+ t.Errorf("got levels %d, want %d", len(s.levelStart), len(tc.expectedLevelStart))
+ }
+ for i := 0; i < len(s.levelStart) && i < len(tc.expectedLevelStart); i++ {
+ if s.levelStart[i] != tc.expectedLevelStart[i] {
+ t.Errorf("got levelStart[%d] %d, want %d", i, s.levelStart[i], tc.expectedLevelStart[i])
+ }
+ }
+ })
+ }
+}
+
+func TestGenerate(t *testing.T) {
+ // The input data has size dataSize. It starts with the data in startWith,
+ // and all other bytes are zeroes.
+ testCases := []struct {
+ dataSize int
+ startWith []byte
+ expectedRoot []byte
+ }{
+ {
+ dataSize: usermem.PageSize,
+ startWith: nil,
+ expectedRoot: []byte{173, 127, 172, 178, 88, 111, 198, 233, 102, 192, 4, 215, 209, 209, 107, 2, 79, 88, 5, 255, 124, 180, 124, 122, 133, 218, 189, 139, 72, 137, 44, 167},
+ },
+ {
+ dataSize: 128*usermem.PageSize + 1,
+ startWith: nil,
+ expectedRoot: []byte{62, 93, 40, 92, 161, 241, 30, 223, 202, 99, 39, 2, 132, 113, 240, 139, 117, 99, 79, 243, 54, 18, 100, 184, 141, 121, 238, 46, 149, 202, 203, 132},
+ },
+ {
+ dataSize: 1,
+ startWith: []byte{'a'},
+ expectedRoot: []byte{52, 75, 204, 142, 172, 129, 37, 14, 145, 137, 103, 203, 11, 162, 209, 205, 30, 169, 213, 72, 20, 28, 243, 24, 242, 2, 92, 43, 169, 59, 110, 210},
+ },
+ {
+ dataSize: 1,
+ startWith: []byte{'1'},
+ expectedRoot: []byte{74, 35, 103, 179, 176, 149, 254, 112, 42, 65, 104, 66, 119, 56, 133, 124, 228, 15, 65, 161, 150, 0, 117, 174, 242, 34, 115, 115, 218, 37, 3, 105},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) {
+ var (
+ data bytes.Buffer
+ tree bytes.Buffer
+ )
+
+ startSize := len(tc.startWith)
+ _, err := data.Write(tc.startWith)
+ if err != nil {
+ t.Fatalf("Failed to write to data: %v", err)
+ }
+ _, err = data.Write(make([]byte, tc.dataSize-startSize))
+ if err != nil {
+ t.Fatalf("Failed to write to data: %v", err)
+ }
+
+ root, err := Generate(&data, int64(tc.dataSize), &tree, &tree)
+ if err != nil {
+ t.Fatalf("Generate failed: %v", err)
+ }
+
+ if !bytes.Equal(root, tc.expectedRoot) {
+ t.Errorf("Unexpected root")
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index 6e85d6fab..8b4a4d905 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -43,16 +43,17 @@ func portFromSockaddr(sa unix.Sockaddr) (uint16, error) {
// pickPort makes a new socket and returns the socket FD and port. The domain should be AF_INET or AF_INET6. The caller must close the FD when done with
// the port if there is no error.
-func pickPort(domain, typ int) (fd int, sa unix.Sockaddr, err error) {
- fd, err = unix.Socket(domain, typ, 0)
+func pickPort(domain, typ int) (int, uint16, error) {
+ fd, err := unix.Socket(domain, typ, 0)
if err != nil {
- return -1, nil, err
+ return -1, 0, err
}
defer func() {
if err != nil {
err = multierr.Append(err, unix.Close(fd))
}
}()
+ var sa unix.Sockaddr
switch domain {
case unix.AF_INET:
var sa4 unix.SockaddrInet4
@@ -63,16 +64,20 @@ func pickPort(domain, typ int) (fd int, sa unix.Sockaddr, err error) {
copy(sa6.Addr[:], net.ParseIP(LocalIPv6).To16())
sa = &sa6
default:
- return -1, nil, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
+ return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
}
if err = unix.Bind(fd, sa); err != nil {
- return -1, nil, err
+ return -1, 0, err
}
sa, err = unix.Getsockname(fd)
if err != nil {
- return -1, nil, err
+ return -1, 0, err
}
- return fd, sa, nil
+ port, err := portFromSockaddr(sa)
+ if err != nil {
+ return -1, 0, err
+ }
+ return fd, port, nil
}
// layerState stores the state of a layer of a connection.
@@ -266,14 +271,10 @@ func SeqNumValue(v seqnum.Value) *seqnum.Value {
}
// newTCPState creates a new TCPState.
-func newTCPState(domain int, out, in TCP) (*tcpState, unix.Sockaddr, error) {
- portPickerFD, localAddr, err := pickPort(domain, unix.SOCK_STREAM)
- if err != nil {
- return nil, nil, err
- }
- localPort, err := portFromSockaddr(localAddr)
+func newTCPState(domain int, out, in TCP) (*tcpState, error) {
+ portPickerFD, localPort, err := pickPort(domain, unix.SOCK_STREAM)
if err != nil {
- return nil, nil, err
+ return nil, err
}
s := tcpState{
out: TCP{SrcPort: &localPort},
@@ -283,12 +284,12 @@ func newTCPState(domain int, out, in TCP) (*tcpState, unix.Sockaddr, error) {
finSent: false,
}
if err := s.out.merge(&out); err != nil {
- return nil, nil, err
+ return nil, err
}
if err := s.in.merge(&in); err != nil {
- return nil, nil, err
+ return nil, err
}
- return &s, localAddr, nil
+ return &s, nil
}
func (s *tcpState) outgoing() Layer {
@@ -374,14 +375,10 @@ type udpState struct {
var _ layerState = (*udpState)(nil)
// newUDPState creates a new udpState.
-func newUDPState(domain int, out, in UDP) (*udpState, unix.Sockaddr, error) {
- portPickerFD, localAddr, err := pickPort(domain, unix.SOCK_DGRAM)
+func newUDPState(domain int, out, in UDP) (*udpState, error) {
+ portPickerFD, localPort, err := pickPort(domain, unix.SOCK_DGRAM)
if err != nil {
- return nil, nil, err
- }
- localPort, err := portFromSockaddr(localAddr)
- if err != nil {
- return nil, nil, err
+ return nil, err
}
s := udpState{
out: UDP{SrcPort: &localPort},
@@ -389,12 +386,12 @@ func newUDPState(domain int, out, in UDP) (*udpState, unix.Sockaddr, error) {
portPickerFD: portPickerFD,
}
if err := s.out.merge(&out); err != nil {
- return nil, nil, err
+ return nil, err
}
if err := s.in.merge(&in); err != nil {
- return nil, nil, err
+ return nil, err
}
- return &s, localAddr, nil
+ return &s, nil
}
func (s *udpState) outgoing() Layer {
@@ -429,7 +426,6 @@ type Connection struct {
layerStates []layerState
injector Injector
sniffer Sniffer
- localAddr unix.Sockaddr
t *testing.T
}
@@ -475,20 +471,45 @@ func (conn *Connection) Close() {
}
}
-// CreateFrame builds a frame for the connection with layer overriding defaults
-// of the innermost layer and additionalLayers added after it.
-func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
+// CreateFrame builds a frame for the connection with defaults overriden
+// from the innermost layer out, and additionalLayers added after it.
+//
+// Note that overrideLayers can have a length that is less than the number
+// of layers in this connection, and in such cases the innermost layers are
+// overriden first. As an example, valid values of overrideLayers for a TCP-
+// over-IPv4-over-Ethernet connection are: nil, [TCP], [IPv4, TCP], and
+// [Ethernet, IPv4, TCP].
+func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...Layer) Layers {
var layersToSend Layers
- for _, s := range conn.layerStates {
- layersToSend = append(layersToSend, s.outgoing())
- }
- if err := layersToSend[len(layersToSend)-1].merge(layer); err != nil {
- conn.t.Fatalf("can't merge %+v into %+v: %s", layer, layersToSend[len(layersToSend)-1], err)
+ for i, s := range conn.layerStates {
+ layer := s.outgoing()
+ // overrideLayers and conn.layerStates have their tails aligned, so
+ // to find the index we move backwards by the distance i is to the
+ // end.
+ if j := len(overrideLayers) - (len(conn.layerStates) - i); j >= 0 {
+ if err := layer.merge(overrideLayers[j]); err != nil {
+ conn.t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err)
+ }
+ }
+ layersToSend = append(layersToSend, layer)
}
layersToSend = append(layersToSend, additionalLayers...)
return layersToSend
}
+// SendFrameStateless sends a frame without updating any of the layer states.
+//
+// This method is useful for sending out-of-band control messages such as
+// ICMP packets, where it would not make sense to update the transport layer's
+// state using the ICMP header.
+func (conn *Connection) SendFrameStateless(frame Layers) {
+ outBytes, err := frame.ToBytes()
+ if err != nil {
+ conn.t.Fatalf("can't build outgoing packet: %s", err)
+ }
+ conn.injector.Send(outBytes)
+}
+
// SendFrame sends a frame on the wire and updates the state of all layers.
func (conn *Connection) SendFrame(frame Layers) {
outBytes, err := frame.ToBytes()
@@ -509,10 +530,13 @@ func (conn *Connection) SendFrame(frame Layers) {
}
}
-// Send a packet with reasonable defaults. Potentially override the final layer
-// in the connection with the provided layer and add additionLayers.
-func (conn *Connection) Send(layer Layer, additionalLayers ...Layer) {
- conn.SendFrame(conn.CreateFrame(layer, additionalLayers...))
+// send sends a packet, possibly with layers of this connection overridden and
+// additional layers added.
+//
+// Types defined with Connection as the underlying type should expose
+// type-safe versions of this method.
+func (conn *Connection) send(overrideLayers Layers, additionalLayers ...Layer) {
+ conn.SendFrame(conn.CreateFrame(overrideLayers, additionalLayers...))
}
// recvFrame gets the next successfully parsed frame (of type Layers) within the
@@ -606,7 +630,7 @@ func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
if err != nil {
t.Fatalf("can't make ipv4State: %s", err)
}
- tcpState, localAddr, err := newTCPState(unix.AF_INET, outgoingTCP, incomingTCP)
+ tcpState, err := newTCPState(unix.AF_INET, outgoingTCP, incomingTCP)
if err != nil {
t.Fatalf("can't make tcpState: %s", err)
}
@@ -623,7 +647,6 @@ func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
layerStates: []layerState{etherState, ipv4State, tcpState},
injector: injector,
sniffer: sniffer,
- localAddr: localAddr,
t: t,
}
}
@@ -705,7 +728,7 @@ func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Dur
// Send a packet with reasonable defaults. Potentially override the TCP layer in
// the connection with the provided layer and add additionLayers.
func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
- (*Connection)(conn).Send(&tcp, additionalLayers...)
+ (*Connection)(conn).send(Layers{&tcp}, additionalLayers...)
}
// Close frees associated resources held by the TCPIPv4 connection.
@@ -727,32 +750,48 @@ func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) {
return gotTCP, err
}
-func (conn *TCPIPv4) state() *tcpState {
- state, ok := conn.layerStates[len(conn.layerStates)-1].(*tcpState)
+func (conn *TCPIPv4) tcpState() *tcpState {
+ state, ok := conn.layerStates[2].(*tcpState)
+ if !ok {
+ conn.t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2])
+ }
+ return state
+}
+
+func (conn *TCPIPv4) ipv4State() *ipv4State {
+ state, ok := conn.layerStates[1].(*ipv4State)
if !ok {
- conn.t.Fatalf("expected final state of %v to be tcpState", conn.layerStates)
+ conn.t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1])
}
return state
}
// RemoteSeqNum returns the next expected sequence number from the DUT.
func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value {
- return conn.state().remoteSeqNum
+ return conn.tcpState().remoteSeqNum
}
// LocalSeqNum returns the next sequence number to send from the testbench.
func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value {
- return conn.state().localSeqNum
+ return conn.tcpState().localSeqNum
}
// SynAck returns the SynAck that was part of the handshake.
func (conn *TCPIPv4) SynAck() *TCP {
- return conn.state().synAck
+ return conn.tcpState().synAck
}
// LocalAddr gets the local socket address of this connection.
-func (conn *TCPIPv4) LocalAddr() unix.Sockaddr {
- return conn.localAddr
+func (conn *TCPIPv4) LocalAddr() *unix.SockaddrInet4 {
+ sa := &unix.SockaddrInet4{Port: int(*conn.tcpState().out.SrcPort)}
+ copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr)
+ return sa
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *TCPIPv4) Drain() {
+ conn.sniffer.Drain()
}
// IPv6Conn maintains the state for all the layers in a IPv6 connection.
@@ -786,15 +825,10 @@ func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn {
}
}
-// SendFrame sends a frame on the wire and updates the state of all layers.
-func (conn *IPv6Conn) SendFrame(frame Layers) {
- (*Connection)(conn).SendFrame(frame)
-}
-
-// CreateFrame builds a frame for the connection with ipv6 overriding the ipv6
-// layer defaults and additionalLayers added after it.
-func (conn *IPv6Conn) CreateFrame(ipv6 IPv6, additionalLayers ...Layer) Layers {
- return (*Connection)(conn).CreateFrame(&ipv6, additionalLayers...)
+// Send sends a frame with ipv6 overriding the IPv6 layer defaults and
+// additionalLayers added after it.
+func (conn *IPv6Conn) Send(ipv6 IPv6, additionalLayers ...Layer) {
+ (*Connection)(conn).send(Layers{&ipv6}, additionalLayers...)
}
// Close to clean up any resources held.
@@ -808,12 +842,6 @@ func (conn *IPv6Conn) ExpectFrame(frame Layers, timeout time.Duration) (Layers,
return (*Connection)(conn).ExpectFrame(frame, timeout)
}
-// Drain drains the sniffer's receive buffer by receiving packets until there's
-// nothing else to receive.
-func (conn *TCPIPv4) Drain() {
- conn.sniffer.Drain()
-}
-
// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection.
type UDPIPv4 Connection
@@ -827,7 +855,7 @@ func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
if err != nil {
t.Fatalf("can't make ipv4State: %s", err)
}
- udpState, localAddr, err := newUDPState(unix.AF_INET, outgoingUDP, incomingUDP)
+ udpState, err := newUDPState(unix.AF_INET, outgoingUDP, incomingUDP)
if err != nil {
t.Fatalf("can't make udpState: %s", err)
}
@@ -844,42 +872,43 @@ func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
layerStates: []layerState{etherState, ipv4State, udpState},
injector: injector,
sniffer: sniffer,
- localAddr: localAddr,
t: t,
}
}
-// LocalAddr gets the local socket address of this connection.
-func (conn *UDPIPv4) LocalAddr() unix.Sockaddr {
- return conn.localAddr
+func (conn *UDPIPv4) udpState() *udpState {
+ state, ok := conn.layerStates[2].(*udpState)
+ if !ok {
+ conn.t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2])
+ }
+ return state
}
-// CreateFrame builds a frame for the connection with layer overriding defaults
-// of the innermost layer and additionalLayers added after it.
-func (conn *UDPIPv4) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
- return (*Connection)(conn).CreateFrame(layer, additionalLayers...)
+func (conn *UDPIPv4) ipv4State() *ipv4State {
+ state, ok := conn.layerStates[1].(*ipv4State)
+ if !ok {
+ conn.t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1])
+ }
+ return state
}
-// Send a packet with reasonable defaults. Potentially override the UDP layer in
-// the connection with the provided layer and add additionLayers.
-func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) {
- (*Connection)(conn).Send(&udp, additionalLayers...)
+// LocalAddr gets the local socket address of this connection.
+func (conn *UDPIPv4) LocalAddr() *unix.SockaddrInet4 {
+ sa := &unix.SockaddrInet4{Port: int(*conn.udpState().out.SrcPort)}
+ copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr)
+ return sa
}
-// SendFrame sends a frame on the wire and updates the state of all layers.
-func (conn *UDPIPv4) SendFrame(frame Layers) {
- (*Connection)(conn).SendFrame(frame)
+// Send sends a packet with reasonable defaults, potentially overriding the UDP
+// layer and adding additionLayers.
+func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) {
+ (*Connection)(conn).send(Layers{&udp}, additionalLayers...)
}
-// SendIP sends a packet with additionalLayers following the IP layer in the
-// connection.
-func (conn *UDPIPv4) SendIP(additionalLayers ...Layer) {
- var layersToSend Layers
- for _, s := range conn.layerStates[:len(conn.layerStates)-1] {
- layersToSend = append(layersToSend, s.outgoing())
- }
- layersToSend = append(layersToSend, additionalLayers...)
- conn.SendFrame(layersToSend)
+// SendIP sends a packet with reasonable defaults, potentially overriding the
+// UDP and IPv4 headers and adding additionLayers.
+func (conn *UDPIPv4) SendIP(ip IPv4, udp UDP, additionalLayers ...Layer) {
+ (*Connection)(conn).send(Layers{&ip, &udp}, additionalLayers...)
}
// Expect expects a frame with the UDP layer matching the provided UDP within
diff --git a/test/packetimpact/tests/icmpv6_param_problem_test.go b/test/packetimpact/tests/icmpv6_param_problem_test.go
index 961059fc1..4d1d9a7f5 100644
--- a/test/packetimpact/tests/icmpv6_param_problem_test.go
+++ b/test/packetimpact/tests/icmpv6_param_problem_test.go
@@ -45,8 +45,8 @@ func TestICMPv6ParamProblemTest(t *testing.T) {
NDPPayload: []byte("hello world"),
}
- toSend := conn.CreateFrame(ipv6, &icmpv6)
- conn.SendFrame(toSend)
+ toSend := (*testbench.Connection)(&conn).CreateFrame(testbench.Layers{&ipv6}, &icmpv6)
+ (*testbench.Connection)(&conn).SendFrame(toSend)
// Build the expected ICMPv6 payload, which includes an index to the
// problematic byte and also the problematic packet as described in
diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
index aedabf9de..b754918f6 100644
--- a/test/packetimpact/tests/udp_icmp_error_propagation_test.go
+++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
@@ -96,24 +96,25 @@ func wantErrno(c connectionMode, icmpErr icmpError) syscall.Errno {
// sendICMPError sends an ICMP error message in response to a UDP datagram.
func sendICMPError(conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UDP) error {
+ layers := (*testbench.Connection)(conn).CreateFrame(nil)
+ layers = layers[:len(layers)-1]
+ ip, ok := udp.Prev().(*testbench.IPv4)
+ if !ok {
+ return fmt.Errorf("expected %s to be IPv4", udp.Prev())
+ }
if icmpErr == timeToLiveExceeded {
- ip, ok := udp.Prev().(*testbench.IPv4)
- if !ok {
- return fmt.Errorf("expected %s to be IPv4", udp.Prev())
- }
*ip.TTL = 1
// Let serialization recalculate the checksum since we set the TTL
// to 1.
ip.Checksum = nil
-
- // Note that the ICMP payload is valid in this case because the UDP
- // payload is empty. If the UDP payload were not empty, the packet
- // length during serialization may not be calculated correctly,
- // resulting in a mal-formed packet.
- conn.SendIP(icmpErr.ToICMPv4(), ip, udp)
- } else {
- conn.SendIP(icmpErr.ToICMPv4(), udp.Prev(), udp)
}
+ // Note that the ICMP payload is valid in this case because the UDP
+ // payload is empty. If the UDP payload were not empty, the packet
+ // length during serialization may not be calculated correctly,
+ // resulting in a mal-formed packet.
+ layers = append(layers, icmpErr.ToICMPv4(), ip, udp)
+
+ (*testbench.Connection)(conn).SendFrameStateless(layers)
return nil
}
diff --git a/test/packetimpact/tests/udp_recv_multicast_test.go b/test/packetimpact/tests/udp_recv_multicast_test.go
index d51a34145..77a9bfa1d 100644
--- a/test/packetimpact/tests/udp_recv_multicast_test.go
+++ b/test/packetimpact/tests/udp_recv_multicast_test.go
@@ -35,8 +35,6 @@ func TestUDPRecvMulticast(t *testing.T) {
defer dut.Close(boundFD)
conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
defer conn.Close()
- frame := conn.CreateFrame(&testbench.UDP{}, &testbench.Payload{Bytes: []byte("hello world")})
- frame[1].(*testbench.IPv4).DstAddr = testbench.Address(tcpip.Address(net.ParseIP("224.0.0.1").To4()))
- conn.SendFrame(frame)
+ conn.SendIP(testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(net.ParseIP("224.0.0.1").To4()))}, testbench.UDP{})
dut.Recv(boundFD, 100, 0)
}
diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go
index bf64803e2..a7db384ad 100644
--- a/test/packetimpact/tests/udp_send_recv_dgram_test.go
+++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go
@@ -59,8 +59,7 @@ func TestUDPRecv(t *testing.T) {
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- frame := conn.CreateFrame(&testbench.UDP{}, &testbench.Payload{Bytes: []byte(tc.payload)})
- conn.SendFrame(frame)
+ conn.Send(testbench.UDP{}, &testbench.Payload{Bytes: []byte(tc.payload)})
if got, want := string(dut.Recv(boundFD, int32(len(tc.payload)), 0)), tc.payload; got != want {
t.Fatalf("received payload does not match sent payload got: %s, want: %s", got, want)
}