diff options
Diffstat (limited to 'test/packetimpact')
43 files changed, 8792 insertions, 0 deletions
diff --git a/test/packetimpact/README.md b/test/packetimpact/README.md new file mode 100644 index 000000000..f46c67a0c --- /dev/null +++ b/test/packetimpact/README.md @@ -0,0 +1,702 @@ +# Packetimpact + +## What is packetimpact? + +Packetimpact is a tool for platform-independent network testing. It is heavily +inspired by [packetdrill](https://github.com/google/packetdrill). It creates two +docker containers connected by a network. One is for the test bench, which +operates the test. The other is for the device-under-test (DUT), which is the +software being tested. The test bench communicates over the network with the DUT +to check correctness of the network. + +### Goals + +Packetimpact aims to provide: + +* A **multi-platform** solution that can test both Linux and gVisor. +* **Conciseness** on par with packetdrill scripts. +* **Control-flow** like for loops, conditionals, and variables. +* **Flexibilty** to specify every byte in a packet or use multiple sockets. + +## How to run packetimpact tests? + +Build the test container image by running the following at the root of the +repository: + +```bash +$ make load-packetimpact +``` + +Run a test, e.g. `fin_wait2_timeout`, against Linux: + +```bash +$ bazel test //test/packetimpact/tests:fin_wait2_timeout_linux_test +``` + +Run the same test, but against gVisor: + +```bash +$ bazel test //test/packetimpact/tests:fin_wait2_timeout_netstack_test +``` + +## When to use packetimpact? + +There are a few ways to write networking tests for gVisor currently: + +* [Go unit tests](https://github.com/google/gvisor/tree/master/pkg/tcpip) +* [syscall tests](https://github.com/google/gvisor/tree/master/test/syscalls/linux) +* [packetdrill tests](https://github.com/google/gvisor/tree/master/test/packetdrill) +* packetimpact tests + +The right choice depends on the needs of the test. + +Feature | Go unit test | syscall test | packetdrill | packetimpact +-------------- | ------------ | ------------ | ----------- | ------------ +Multi-platform | no | **YES** | **YES** | **YES** +Concise | no | somewhat | somewhat | **VERY** +Control-flow | **YES** | **YES** | no | **YES** +Flexible | **VERY** | no | somewhat | **VERY** + +### Go unit tests + +If the test depends on the internals of gVisor and doesn't need to run on Linux +or other platforms for comparison purposes, a Go unit test can be appropriate. +They can observe internals of gVisor networking. The downside is that they are +**not concise** and **not multi-platform**. If you require insight on gVisor +internals, this is the right choice. + +### Syscall tests + +Syscall tests are **multi-platform** but cannot examine the internals of gVisor +networking. They are **concise**. They can use **control-flow** structures like +conditionals, for loops, and variables. However, they are limited to only what +the POSIX interface provides so they are **not flexible**. For example, you +would have difficulty writing a syscall test that intentionally sends a bad IP +checksum. Or if you did write that test with raw sockets, it would be very +**verbose** to write a test that intentionally send wrong checksums, wrong +protocols, wrong sequence numbers, etc. + +### Packetdrill tests + +Packetdrill tests are **multi-platform** and can run against both Linux and +gVisor. They are **concise** and use a special packetdrill scripting language. +They are **more flexible** than a syscall test in that they can send packets +that a syscall test would have difficulty sending, like a packet with a +calcuated ACK number. But they are also somewhat limimted in flexibiilty in that +they can't do tests with multiple sockets. They have **no control-flow** ability +like variables or conditionals. For example, it isn't possible to send a packet +that depends on the window size of a previous packet because the packetdrill +language can't express that. Nor could you branch based on whether or not the +other side supports window scaling, for example. + +### Packetimpact tests + +Packetimpact tests are similar to Packetdrill tests except that they are written +in Go instead of the packetdrill scripting language. That gives them all the +**control-flow** abilities of Go (loops, functions, variables, etc). They are +**multi-platform** in the same way as packetdrill tests but even more +**flexible** because Go is more expressive than the scripting language of +packetdrill. However, Go is **not as concise** as the packetdrill language. Many +design decisions below are made to mitigate that. + +## How it works + +``` + Testbench Device-Under-Test (DUT) + +-------------------+ +------------------------+ + | | TEST NET | | + | rawsockets.go <-->| <===========> | <---+ | + | ^ | | | | + | | | | | | + | v | | | | + | unittest | | | | + | ^ | | | | + | | | | | | + | v | | v | + | dut.go <========gRPC========> posix server | + | | CONTROL NET | | + +-------------------+ +------------------------+ +``` + +Two docker containers are created by a "runner" script, one for the testbench +and the other for the device under test (DUT). The script connects the two +containers with a control network and test network. It also does some other +tasks like waiting until the DUT is ready before starting the test and disabling +Linux networking that would interfere with the test bench. + +### DUT + +The DUT container runs a program called the "posix_server". The posix_server is +written in c++ for maximum portability. It is compiled on the host. The script +that starts the containers copies it into the DUT's container and runs it. It's +job is to receive directions from the test bench on what actions to take. For +this, the posix_server does three steps in a loop: + +1. Listen for a request from the test bench. +2. Execute a command. +3. Send the response back to the test bench. + +The requests and responses are +[protobufs](https://developers.google.com/protocol-buffers) and the +communication is done with [gRPC](https://grpc.io/). The commands run are +[POSIX socket commands](https://en.wikipedia.org/wiki/Berkeley_sockets#Socket_API_functions), +with the inputs and outputs converted into protobuf requests and responses. All +communication is on the control network, so that the test network is unaffected +by extra packets. + +For example, this is the request and response pair to call +[`socket()`](http://man7.org/linux/man-pages/man2/socket.2.html): + +```protocol-buffer +message SocketRequest { + int32 domain = 1; + int32 type = 2; + int32 protocol = 3; +} + +message SocketResponse { + int32 fd = 1; + int32 errno_ = 2; +} +``` + +##### Alternatives considered + +* We could have use JSON for communication instead. It would have been a + lighter-touch than protobuf but protobuf handles all the data type and has + strict typing to prevent a class of errors. The test bench could be written + in other languages, too. +* Instead of mimicking the POSIX interfaces, arguments could have had a more + natural form, like the `bind()` getting a string IP address instead of bytes + in a `sockaddr_t`. However, conforming to the existing structures keeps more + of the complexity in Go and keeps the posix_server simpler and thus more + likely to compile everywhere. + +### Test Bench + +The test bench does most of the work in a test. It is a Go program that compiles +on the host and is copied by the script into test bench's container. It is a +regular [go unit test](https://golang.org/pkg/testing/) that imports the test +bench framework. The test bench framwork is based on three basic utilities: + +* Commanding the DUT to run POSIX commands and return responses. +* Sending raw packets to the DUT on the test network. +* Listening for raw packets from the DUT on the test network. + +#### DUT commands + +To keep the interface to the DUT consistent and easy-to-use, each POSIX command +supported by the posix_server is wrapped in functions with signatures similar to +the ones in the [Go unix package](https://godoc.org/golang.org/x/sys/unix). This +way all the details of endianess and (un)marshalling of go structs such as +[unix.Timeval](https://godoc.org/golang.org/x/sys/unix#Timeval) is handled in +one place. This also makes it straight-forward to convert tests that use `unix.` +or `syscall.` calls to `dut.` calls. + +For example, creating a connection to the DUT and commanding it to make a socket +looks like this: + +```go +dut := testbench.NewDut(t) +fd, err := dut.SocketWithErrno(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP) +if fd < 0 { + t.Fatalf(...) +} +``` + +Because the usual case is to fail the test when the DUT fails to create a +socket, there is a concise version of each of the `...WithErrno` functions that +does that: + +```go +dut := testbench.NewDut(t) +fd := dut.Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP) +``` + +The DUT and other structs in the code store a `*testing.T` so that they can +provide versions of functions that call `t.Fatalf(...)`. This helps keep tests +concise. + +##### Alternatives considered + +* Instead of mimicking the `unix.` go interface, we could have invented a more + natural one, like using `float64` instead of `Timeval`. However, using the + same function signatures that `unix.` has makes it easier to convert code to + `dut.`. Also, using an existing interface ensures that we don't invent an + interface that isn't extensible. For example, if we invented a function for + `bind()` that didn't support IPv6 and later we had to add a second `bind6()` + function. + +#### Sending/Receiving Raw Packets + +The framework wraps POSIX sockets for sending and receiving raw frames. Both +send and receive are synchronous commands. +[SO_RCVTIMEO](http://man7.org/linux/man-pages/man7/socket.7.html) is used to set +a timeout on the receive commands. For ease of use, these are wrapped in an +`Injector` and a `Sniffer`. They have functions: + +```go +func (s *Sniffer) Recv(timeout time.Duration) []byte {...} +func (i *Injector) Send(b []byte) {...} +``` + +##### Alternatives considered + +* [gopacket](https://github.com/google/gopacket) pcap has raw socket support + but requires cgo. cgo is not guaranteed to be portable from the host to the + container and in practice, the container doesn't recognize binaries built on + the host if they use cgo. +* Both gVisor and gopacket have the ability to read and write pcap files + without cgo but that is insufficient here because we can't just replay pcap + files, we need a more dynamic solution. +* The sniffer and injector can't share a socket because they need to be bound + differently. +* Sniffing could have been done asynchronously with channels, obviating the + need for `SO_RCVTIMEO`. But that would introduce asynchronous complication. + `SO_RCVTIMEO` is well supported on the test bench. + +#### `Layer` struct + +A large part of packetimpact tests is creating packets to send and comparing +received packets against expectations. To keep tests concise, it is useful to be +able to specify just the important parts of packets that need to be set. For +example, sending a packet with default values except for TCP Flags. And for +packets received, it's useful to be able to compare just the necessary parts of +received packets and ignore the rest. + +To aid in both of those, Go structs with optional fields are created for each +encapsulation type, such as IPv4, TCP, and Ethernet. This is inspired by +[scapy](https://scapy.readthedocs.io/en/latest/). For example, here is the +struct for Ethernet: + +```go +type Ether struct { + LayerBase + SrcAddr *tcpip.LinkAddress + DstAddr *tcpip.LinkAddress + Type *tcpip.NetworkProtocolNumber +} +``` + +Each struct has the same fields as those in the +[gVisor headers](https://github.com/google/gvisor/tree/master/pkg/tcpip/header) +but with a pointer for each field that may be `nil`. + +##### Alternatives considered + +* Just use []byte like gVisor headers do. The drawback is that it makes the + tests more verbose. + * For example, there would be no way to call `Send(myBytes)` concisely and + indicate if the checksum should be calculated automatically versus + overridden. The only way would be to add lines to the test to calculate + it before each Send, which is wordy. Or make multiple versions of Send: + one that checksums IP, one that doesn't, one that checksums TCP, one + that does both, etc. That would be many combinations. + * Filtering inputs would become verbose. Either: + * large conditionals that need to be repeated many places: + `h[FlagOffset] == SYN && h[LengthOffset:LengthOffset+2] == ...` or + * Many functions, one per field, like: `filterByFlag(myBytes, SYN)`, + `filterByLength(myBytes, 20)`, `filterByNextProto(myBytes, 0x8000)`, + etc. + * Using pointers allows us to combine `Layer`s with reflection. So the + default `Layers` can be overridden by a `Layers` with just the TCP + conection's src/dst which can be overridden by one with just a test + specific TCP window size. + * It's a proven way to separate the details of a packet from the byte + format as shown by scapy's success. +* Use packetgo. It's more general than parsing packets with gVisor. However: + * packetgo doesn't have optional fields so many of the above problems + still apply. + * It would be yet another dependency. + * It's not as well known to engineers that are already writing gVisor + code. + * It might be a good candidate for replacing the parsing of packets into + `Layer`s if all that parsing turns out to be more work than parsing by + packetgo and converting *that* to `Layer`. packetgo has easier to use + getters for the layers. This could be done later in a way that doesn't + break tests. + +#### `Layer` methods + +The `Layer` structs provide a way to partially specify an encapsulation. They +also need methods for using those partially specified encapsulation, for example +to marshal them to bytes or compare them. For those, each encapsulation +implements the `Layer` interface: + +```go +// Layer is the interface that all encapsulations must implement. +// +// A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A +// Layer contains all the fields of the encapsulation. Each field is a pointer +// and may be nil. +type Layer interface { + // toBytes converts the Layer into bytes. In places where the Layer's field + // isn't nil, the value that is pointed to is used. When the field is nil, a + // reasonable default for the Layer is used. For example, "64" for IPv4 TTL + // and a calculated checksum for TCP or IP. Some layers require information + // from the previous or next layers in order to compute a default, such as + // TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list + // to the layer's neighbors. + toBytes() ([]byte, error) + + // match checks if the current Layer matches the provided Layer. If either + // Layer has a nil in a given field, that field is considered matching. + // Otherwise, the values pointed to by the fields must match. + match(Layer) bool + + // length in bytes of the current encapsulation + length() int + + // next gets a pointer to the encapsulated Layer. + next() Layer + + // prev gets a pointer to the Layer encapsulating this one. + prev() Layer + + // setNext sets the pointer to the encapsulated Layer. + setNext(Layer) + + // setPrev sets the pointer to the Layer encapsulating this one. + setPrev(Layer) +} +``` + +The `next` and `prev` make up a link listed so that each layer can get at the +information in the layer around it. This is necessary for some protocols, like +TCP that needs the layer before and payload after to compute the checksum. Any +sequence of `Layer` structs is valid so long as the parser and `toBytes` +functions can map from type to protool number and vice-versa. When the mapping +fails, an error is emitted explaining what functionality is missing. The +solution is either to fix the ordering or implement the missing protocol. + +For each `Layer` there is also a parsing function. For example, this one is for +Ethernet: + +``` +func ParseEther(b []byte) (Layers, error) +``` + +The parsing function converts bytes received on the wire into a `Layer` +(actually `Layers`, see below) which has no `nil`s in it. By using +`match(Layer)` to compare against another `Layer` that *does* have `nil`s in it, +the received bytes can be partially compared. The `nil`s behave as +"don't-cares". + +##### Alternatives considered + +* Matching against `[]byte` instead of converting to `Layer` first. + * The downside is that it precludes the use of a `cmp.Equal` one-liner to + do comparisons. + * It creates confusion in the code to deal with both representations at + different times. For example, is the checksum calculated on `[]byte` or + `Layer` when sending? What about when checking received packets? + +#### `Layers` + +``` +type Layers []Layer + +func (ls *Layers) match(other Layers) bool {...} +func (ls *Layers) toBytes() ([]byte, error) {...} +``` + +`Layers` is an array of `Layer`. It represents a stack of encapsulations, such +as `Layers{Ether{},IPv4{},TCP{},Payload{}}`. It also has `toBytes()` and +`match(Layers)`, like `Layer`. The parse functions above actually return +`Layers` and not `Layer` because they know about the headers below and +sequentially call each parser on the remaining, encapsulated bytes. + +All this leads to the ability to write concise packet processing. For example: + +```go +etherType := 0x8000 +flags = uint8(header.TCPFlagSyn|header.TCPFlagAck) +toMatch := Layers{Ether{Type: ðerType}, IPv4{}, TCP{Flags: &flags}} +for { + recvBytes := sniffer.Recv(time.Second) + if recvBytes == nil { + println("Got no packet for 1 second") + } + gotPacket, err := ParseEther(recvBytes) + if err == nil && toMatch.match(gotPacket) { + println("Got a TCP/IPv4/Eth packet with SYNACK") + } +} +``` + +##### Alternatives considered + +* Don't use previous and next pointers. + * Each layer may need to be able to interrogate the layers around it, like + for computing the next protocol number or total length. So *some* + mechanism is needed for a `Layer` to see neighboring layers. + * We could pass the entire array `Layers` to the `toBytes()` function. + Passing an array to a method that includes in the array the function + receiver itself seems wrong. + +#### `layerState` + +`Layers` represents the different headers of a packet but a connection includes +more state. For example, a TCP connection needs to keep track of the next +expected sequence number and also the next sequence number to send. This is +stored in a `layerState` struct. This is the `layerState` for TCP: + +```go +// tcpState maintains state about a TCP connection. +type tcpState struct { + out, in TCP + localSeqNum, remoteSeqNum *seqnum.Value + synAck *TCP + portPickerFD int + finSent bool +} +``` + +The next sequence numbers for each side of the connection are stored. `out` and +`in` have defaults for the TCP header, such as the expected source and +destination ports for outgoing packets and incoming packets. + +##### `layerState` interface + +```go +// layerState stores the state of a layer of a connection. +type layerState interface { + // outgoing returns an outgoing layer to be sent in a frame. + outgoing() Layer + + // incoming creates an expected Layer for comparing against a received Layer. + // Because the expectation can depend on values in the received Layer, it is + // an input to incoming. For example, the ACK number needs to be checked in a + // TCP packet but only if the ACK flag is set in the received packet. + incoming(received Layer) Layer + + // sent updates the layerState based on the Layer that was sent. The input is + // a Layer with all prev and next pointers populated so that the entire frame + // as it was sent is available. + sent(sent Layer) error + + // received updates the layerState based on a Layer that is receieved. The + // input is a Layer with all prev and next pointers populated so that the + // entire frame as it was receieved is available. + received(received Layer) error + + // close frees associated resources held by the LayerState. + close() error +} +``` + +`outgoing` generates the default Layer for an outgoing packet. For TCP, this +would be a `TCP` with the source and destination ports populated. Because they +are static, they are stored inside the `out` member of `tcpState`. However, the +sequence numbers change frequently so the outgoing sequence number is stored in +the `localSeqNum` and put into the output of outgoing for each call. + +`incoming` does the same functions for packets that arrive but instead of +generating a packet to send, it generates an expect packet for filtering packets +that arrive. For example, if a `TCP` header arrives with the wrong ports, it can +be ignored as belonging to a different connection. `incoming` needs the received +header itself as an input because the filter may depend on the input. For +example, the expected sequence number depends on the flags in the TCP header. + +`sent` and `received` are run for each header that is actually sent or received +and used to update the internal state. `incoming` and `outgoing` should *not* be +used for these purpose. For example, `incoming` is called on every packet that +arrives but only packets that match ought to actually update the state. +`outgoing` is called to created outgoing packets and those packets are always +sent, so unlike `incoming`/`received`, there is one `outgoing` call for each +`sent` call. + +`close` cleans up after the layerState. For example, TCP and UDP need to keep a +port reserved and then release it. + +#### Connections + +Using `layerState` above, we can create connections. + +```go +// Connection holds a collection of layer states for maintaining a connection +// along with sockets for sniffer and injecting packets. +type Connection struct { + layerStates []layerState + injector Injector + sniffer Sniffer + t *testing.T +} +``` + +The connection stores an array of `layerState` in the order that the headers +should be present in the frame to send. For example, Ether then IPv4 then TCP. +The injector and sniffer are for writing and reading frames. A `*testing.T` is +stored so that internal errors can be reported directly without code in the unit +test. + +The `Connection` has some useful functions: + +```go +// Close frees associated resources held by the Connection. +func (conn *Connection) Close() {...} +// CreateFrame builds a frame for the connection with layer overriding defaults +// of the innermost layer and additionalLayers added after it. +func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {...} +// SendFrame sends a frame on the wire and updates the state of all layers. +func (conn *Connection) SendFrame(frame Layers) {...} +// Send a packet with reasonable defaults. Potentially override the final layer +// in the connection with the provided layer and add additionLayers. +func (conn *Connection) Send(layer Layer, additionalLayers ...Layer) {...} +// Expect a frame with the final layerStates layer matching the provided Layer +// within the timeout specified. If it doesn't arrive in time, it returns nil. +func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) {...} +// ExpectFrame expects a frame that matches the provided Layers within the +// timeout specified. If it doesn't arrive in time, it returns nil. +func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) {...} +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *Connection) Drain() {...} +``` + +`CreateFrame` uses the `[]layerState` to create a frame to send. The first +argument is for overriding defaults in the last header of the frame, because +this is the most common need. For a TCPIPv4 connection, this would be the TCP +header. Optional additionalLayers can be specified to add to the frame being +created, such as a `Payload` for `TCP`. + +`SendFrame` sends the frame to the DUT. It is combined with `CreateFrame` to +make `Send`. For unittests with basic sending needs, `Send` can be used. If more +control is needed over the frame, it can be made with `CreateFrame`, modified in +the unit test, and then sent with `SendFrame`. + +On the receiving side, there is `Expect` and `ExpectFrame`. Like with the +sending side, there are two forms of each function, one for just the last header +and one for the whole frame. The expect functions use the `[]layerState` to +create a template for the expected incoming frame. That frame is then overridden +by the values in the first argument. Finally, a loop starts sniffing packets on +the wire for frames. If a matching frame is found before the timeout, it is +returned without error. If not, nil is returned and the error contains text of +all the received frames that didn't match. Exactly one of the outputs will be +non-nil, even if no frames are received at all. + +`Drain` sniffs and discards all the frames that have yet to be received. A +common way to write a test is: + +```go +conn.Drain() // Discard all outstanding frames. +conn.Send(...) // Send a frame with overrides. +// Now expect a frame with a certain header and fail if it doesn't arrive. +if _, err := conn.Expect(...); err != nil { t.Fatal(...) } +``` + +Or for a test where we want to check that no frame arrives: + +```go +if gotOne, _ := conn.Expect(...); gotOne != nil { t.Fatal(...) } +``` + +#### Specializing `Connection` + +Because there are some common combinations of `layerState` into `Connection`, +they are defined: + +```go +// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection. +type TCPIPv4 Connection +// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection. +type UDPIPv4 Connection +``` + +Each has a `NewXxx` function to create a new connection with reasonable +defaults. They also have functions that call the underlying `Connection` +functions but with specialization and tighter type-checking. For example: + +```go +func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { + (*Connection)(conn).Send(&tcp, additionalLayers...) +} +func (conn *TCPIPv4) Drain() { + conn.sniffer.Drain() +} +``` + +They may also have some accessors to get or set the internal state of the +connection: + +```go +func (conn *TCPIPv4) state() *tcpState { + state, ok := conn.layerStates[len(conn.layerStates)-1].(*tcpState) + if !ok { + conn.t.Fatalf("expected final state of %v to be tcpState", conn.layerStates) + } + return state +} +func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value { + return conn.state().remoteSeqNum +} +func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value { + return conn.state().localSeqNum +} +``` + +Unittests will in practice use these functions and not the functions on +`Connection`. For example, `NewTCPIPv4()` and then call `Send` on that rather +than cast is to a `Connection` and call `Send` on that cast result. + +##### Alternatives considered + +* Instead of storing `outgoing` and `incoming`, store values. + * There would be many more things to store instead, like `localMac`, + `remoteMac`, `localIP`, `remoteIP`, `localPort`, and `remotePort`. + * Construction of a packet would be many lines to copy each of these + values into a `[]byte`. And there would be slight variations needed for + each encapsulation stack, like TCPIPv6 and ARP. + * Filtering incoming packets would be a long sequence: + * Compare the MACs, then + * Parse the next header, then + * Compare the IPs, then + * Parse the next header, then + * Compare the TCP ports. Instead it's all just one call to + `cmp.Equal(...)`, for all sequences. + * A TCPIPv6 connection could share most of the code. Only the type of the + IP addresses are different. The types of `outgoing` and `incoming` would + be remain `Layers`. + * An ARP connection could share all the Ethernet parts. The IP `Layer` + could be factored out of `outgoing`. After that, the IPv4 and IPv6 + connections could implement one interface and a single TCP struct could + have either network protocol through composition. + +## Putting it all together + +Here's what te start of a packetimpact unit test looks like. This test creates a +TCP connection with the DUT. There are added comments for explanation in this +document but a real test might not include them in order to stay even more +concise. + +```go +func TestMyTcpTest(t *testing.T) { + // Prepare a DUT for communication. + dut := testbench.NewDUT(t) + + // This does: + // dut.Socket() + // dut.Bind() + // dut.Getsockname() to learn the new port number + // dut.Listen() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) // Tell the DUT to close the socket at the end of the test. + + // Monitor a new TCP connection with sniffer, injector, sequence number tracking, + // and reasonable outgoing and incoming packet field default IPs, MACs, and port numbers. + conn := testbench.NewTCPIPv4(t, dut, remotePort) + + // Perform a 3-way handshake: send SYN, expect SYNACK, send ACK. + conn.Handshake() + + // Tell the DUT to accept the new connection. + acceptFD := dut.Accept(acceptFd) +} +``` + +## Other notes + +* The time between receiving a SYN-ACK and replying with an ACK in `Handshake` + is about 3ms. This is much slower than the native unix response, which is + about 0.3ms. Packetdrill gets closer to 0.3ms. For tests where timing is + crucial, packetdrill is faster and more precise. diff --git a/test/packetimpact/dut/BUILD b/test/packetimpact/dut/BUILD new file mode 100644 index 000000000..3ce63c2c6 --- /dev/null +++ b/test/packetimpact/dut/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "cc_binary", "grpcpp") + +package( + default_visibility = ["//test/packetimpact:__subpackages__"], + licenses = ["notice"], +) + +cc_binary( + name = "posix_server", + srcs = ["posix_server.cc"], + linkstatic = 1, + static = True, # This is needed for running in a docker container. + deps = [ + grpcpp, + "//test/packetimpact/proto:posix_server_cc_grpc_proto", + "//test/packetimpact/proto:posix_server_cc_proto", + ], +) diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc new file mode 100644 index 000000000..a1a5c3612 --- /dev/null +++ b/test/packetimpact/dut/posix_server.cc @@ -0,0 +1,365 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <fcntl.h> +#include <getopt.h> +#include <netdb.h> +#include <netinet/in.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include <iostream> +#include <unordered_map> + +#include "include/grpcpp/security/server_credentials.h" +#include "include/grpcpp/server_builder.h" +#include "test/packetimpact/proto/posix_server.grpc.pb.h" +#include "test/packetimpact/proto/posix_server.pb.h" + +// Converts a sockaddr_storage to a Sockaddr message. +::grpc::Status sockaddr_to_proto(const sockaddr_storage &addr, + socklen_t addrlen, + posix_server::Sockaddr *sockaddr_proto) { + switch (addr.ss_family) { + case AF_INET: { + auto addr_in = reinterpret_cast<const sockaddr_in *>(&addr); + auto response_in = sockaddr_proto->mutable_in(); + response_in->set_family(addr_in->sin_family); + response_in->set_port(ntohs(addr_in->sin_port)); + response_in->mutable_addr()->assign( + reinterpret_cast<const char *>(&addr_in->sin_addr.s_addr), 4); + return ::grpc::Status::OK; + } + case AF_INET6: { + auto addr_in6 = reinterpret_cast<const sockaddr_in6 *>(&addr); + auto response_in6 = sockaddr_proto->mutable_in6(); + response_in6->set_family(addr_in6->sin6_family); + response_in6->set_port(ntohs(addr_in6->sin6_port)); + response_in6->set_flowinfo(ntohl(addr_in6->sin6_flowinfo)); + response_in6->mutable_addr()->assign( + reinterpret_cast<const char *>(&addr_in6->sin6_addr.s6_addr), 16); + response_in6->set_scope_id(ntohl(addr_in6->sin6_scope_id)); + return ::grpc::Status::OK; + } + } + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Unknown Sockaddr"); +} + +::grpc::Status proto_to_sockaddr(const posix_server::Sockaddr &sockaddr_proto, + sockaddr_storage *addr, socklen_t *addr_len) { + switch (sockaddr_proto.sockaddr_case()) { + case posix_server::Sockaddr::SockaddrCase::kIn: { + auto proto_in = sockaddr_proto.in(); + if (proto_in.addr().size() != 4) { + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "IPv4 address must be 4 bytes"); + } + auto addr_in = reinterpret_cast<sockaddr_in *>(addr); + addr_in->sin_family = proto_in.family(); + addr_in->sin_port = htons(proto_in.port()); + proto_in.addr().copy(reinterpret_cast<char *>(&addr_in->sin_addr.s_addr), + 4); + *addr_len = sizeof(*addr_in); + break; + } + case posix_server::Sockaddr::SockaddrCase::kIn6: { + auto proto_in6 = sockaddr_proto.in6(); + if (proto_in6.addr().size() != 16) { + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "IPv6 address must be 16 bytes"); + } + auto addr_in6 = reinterpret_cast<sockaddr_in6 *>(addr); + addr_in6->sin6_family = proto_in6.family(); + addr_in6->sin6_port = htons(proto_in6.port()); + addr_in6->sin6_flowinfo = htonl(proto_in6.flowinfo()); + proto_in6.addr().copy( + reinterpret_cast<char *>(&addr_in6->sin6_addr.s6_addr), 16); + addr_in6->sin6_scope_id = htonl(proto_in6.scope_id()); + *addr_len = sizeof(*addr_in6); + break; + } + case posix_server::Sockaddr::SockaddrCase::SOCKADDR_NOT_SET: + default: + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Unknown Sockaddr"); + } + return ::grpc::Status::OK; +} + +class PosixImpl final : public posix_server::Posix::Service { + ::grpc::Status Accept(grpc_impl::ServerContext *context, + const ::posix_server::AcceptRequest *request, + ::posix_server::AcceptResponse *response) override { + sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + response->set_fd(accept(request->sockfd(), + reinterpret_cast<sockaddr *>(&addr), &addrlen)); + response->set_errno_(errno); + return sockaddr_to_proto(addr, addrlen, response->mutable_addr()); + } + + ::grpc::Status Bind(grpc_impl::ServerContext *context, + const ::posix_server::BindRequest *request, + ::posix_server::BindResponse *response) override { + if (!request->has_addr()) { + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Missing address"); + } + + sockaddr_storage addr; + socklen_t addr_len; + auto err = proto_to_sockaddr(request->addr(), &addr, &addr_len); + if (!err.ok()) { + return err; + } + + response->set_ret( + bind(request->sockfd(), reinterpret_cast<sockaddr *>(&addr), addr_len)); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Close(grpc_impl::ServerContext *context, + const ::posix_server::CloseRequest *request, + ::posix_server::CloseResponse *response) override { + response->set_ret(close(request->fd())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Connect(grpc_impl::ServerContext *context, + const ::posix_server::ConnectRequest *request, + ::posix_server::ConnectResponse *response) override { + if (!request->has_addr()) { + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Missing address"); + } + sockaddr_storage addr; + socklen_t addr_len; + auto err = proto_to_sockaddr(request->addr(), &addr, &addr_len); + if (!err.ok()) { + return err; + } + + response->set_ret(connect(request->sockfd(), + reinterpret_cast<sockaddr *>(&addr), addr_len)); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Fcntl(grpc_impl::ServerContext *context, + const ::posix_server::FcntlRequest *request, + ::posix_server::FcntlResponse *response) override { + response->set_ret(::fcntl(request->fd(), request->cmd(), request->arg())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status GetSockName( + grpc_impl::ServerContext *context, + const ::posix_server::GetSockNameRequest *request, + ::posix_server::GetSockNameResponse *response) override { + sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + response->set_ret(getsockname( + request->sockfd(), reinterpret_cast<sockaddr *>(&addr), &addrlen)); + response->set_errno_(errno); + return sockaddr_to_proto(addr, addrlen, response->mutable_addr()); + } + + ::grpc::Status GetSockOpt( + grpc_impl::ServerContext *context, + const ::posix_server::GetSockOptRequest *request, + ::posix_server::GetSockOptResponse *response) override { + switch (request->type()) { + case ::posix_server::GetSockOptRequest::BYTES: { + socklen_t optlen = request->optlen(); + std::vector<char> buf(optlen); + response->set_ret(::getsockopt(request->sockfd(), request->level(), + request->optname(), buf.data(), + &optlen)); + if (optlen >= 0) { + response->mutable_optval()->set_bytesval(buf.data(), optlen); + } + break; + } + case ::posix_server::GetSockOptRequest::INT: { + int intval = 0; + socklen_t optlen = sizeof(intval); + response->set_ret(::getsockopt(request->sockfd(), request->level(), + request->optname(), &intval, &optlen)); + response->mutable_optval()->set_intval(intval); + break; + } + case ::posix_server::GetSockOptRequest::TIME: { + timeval tv; + socklen_t optlen = sizeof(tv); + response->set_ret(::getsockopt(request->sockfd(), request->level(), + request->optname(), &tv, &optlen)); + response->mutable_optval()->mutable_timeval()->set_seconds(tv.tv_sec); + response->mutable_optval()->mutable_timeval()->set_microseconds( + tv.tv_usec); + break; + } + default: + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Unknown SockOpt Type"); + } + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Listen(grpc_impl::ServerContext *context, + const ::posix_server::ListenRequest *request, + ::posix_server::ListenResponse *response) override { + response->set_ret(listen(request->sockfd(), request->backlog())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Send(::grpc::ServerContext *context, + const ::posix_server::SendRequest *request, + ::posix_server::SendResponse *response) override { + response->set_ret(::send(request->sockfd(), request->buf().data(), + request->buf().size(), request->flags())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status SendTo(::grpc::ServerContext *context, + const ::posix_server::SendToRequest *request, + ::posix_server::SendToResponse *response) override { + if (!request->has_dest_addr()) { + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Missing address"); + } + sockaddr_storage addr; + socklen_t addr_len; + auto err = proto_to_sockaddr(request->dest_addr(), &addr, &addr_len); + if (!err.ok()) { + return err; + } + + response->set_ret(::sendto(request->sockfd(), request->buf().data(), + request->buf().size(), request->flags(), + reinterpret_cast<sockaddr *>(&addr), addr_len)); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status SetSockOpt( + grpc_impl::ServerContext *context, + const ::posix_server::SetSockOptRequest *request, + ::posix_server::SetSockOptResponse *response) override { + switch (request->optval().val_case()) { + case ::posix_server::SockOptVal::kBytesval: + response->set_ret(setsockopt(request->sockfd(), request->level(), + request->optname(), + request->optval().bytesval().c_str(), + request->optval().bytesval().size())); + break; + case ::posix_server::SockOptVal::kIntval: { + int opt = request->optval().intval(); + response->set_ret(::setsockopt(request->sockfd(), request->level(), + request->optname(), &opt, sizeof(opt))); + break; + } + case ::posix_server::SockOptVal::kTimeval: { + timeval tv = {.tv_sec = static_cast<__time_t>( + request->optval().timeval().seconds()), + .tv_usec = static_cast<__suseconds_t>( + request->optval().timeval().microseconds())}; + response->set_ret(setsockopt(request->sockfd(), request->level(), + request->optname(), &tv, sizeof(tv))); + break; + } + default: + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Unknown SockOpt Type"); + } + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Socket(grpc_impl::ServerContext *context, + const ::posix_server::SocketRequest *request, + ::posix_server::SocketResponse *response) override { + response->set_fd( + socket(request->domain(), request->type(), request->protocol())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Recv(::grpc::ServerContext *context, + const ::posix_server::RecvRequest *request, + ::posix_server::RecvResponse *response) override { + std::vector<char> buf(request->len()); + response->set_ret( + recv(request->sockfd(), buf.data(), buf.size(), request->flags())); + if (response->ret() >= 0) { + response->set_buf(buf.data(), response->ret()); + } + response->set_errno_(errno); + return ::grpc::Status::OK; + } +}; + +// Parse command line options. Returns a pointer to the first argument beyond +// the options. +void parse_command_line_options(int argc, char *argv[], std::string *ip, + int *port) { + static struct option options[] = {{"ip", required_argument, NULL, 1}, + {"port", required_argument, NULL, 2}, + {0, 0, 0, 0}}; + + // Parse the arguments. + int c; + while ((c = getopt_long(argc, argv, "", options, NULL)) > 0) { + if (c == 1) { + *ip = optarg; + } else if (c == 2) { + *port = std::stoi(std::string(optarg)); + } + } +} + +void run_server(const std::string &ip, int port) { + PosixImpl posix_service; + grpc::ServerBuilder builder; + std::string server_address = ip + ":" + std::to_string(port); + // Set the authentication mechanism. + std::shared_ptr<grpc::ServerCredentials> creds = + grpc::InsecureServerCredentials(); + builder.AddListeningPort(server_address, creds); + builder.RegisterService(&posix_service); + + std::unique_ptr<grpc::Server> server(builder.BuildAndStart()); + std::cerr << "Server listening on " << server_address << std::endl; + server->Wait(); + std::cerr << "posix_server is finished." << std::endl; +} + +int main(int argc, char *argv[]) { + std::cerr << "posix_server is starting." << std::endl; + std::string ip; + int port; + parse_command_line_options(argc, argv, &ip, &port); + + std::cerr << "Got IP " << ip << " and port " << port << "." << std::endl; + run_server(ip, port); +} diff --git a/test/packetimpact/netdevs/BUILD b/test/packetimpact/netdevs/BUILD new file mode 100644 index 000000000..422bb9b0c --- /dev/null +++ b/test/packetimpact/netdevs/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_library") + +package( + licenses = ["notice"], +) + +go_library( + name = "netdevs", + srcs = ["netdevs.go"], + visibility = ["//test/packetimpact:__subpackages__"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + ], +) diff --git a/test/packetimpact/netdevs/netdevs.go b/test/packetimpact/netdevs/netdevs.go new file mode 100644 index 000000000..d2c9cfeaf --- /dev/null +++ b/test/packetimpact/netdevs/netdevs.go @@ -0,0 +1,104 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package netdevs contains utilities for working with network devices. +package netdevs + +import ( + "fmt" + "net" + "regexp" + "strings" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// A DeviceInfo represents a network device. +type DeviceInfo struct { + MAC net.HardwareAddr + IPv4Addr net.IP + IPv4Net *net.IPNet + IPv6Addr net.IP + IPv6Net *net.IPNet +} + +var ( + deviceLine = regexp.MustCompile(`^\s*\d+: (\w+)`) + linkLine = regexp.MustCompile(`^\s*link/\w+ ([0-9a-fA-F:]+)`) + inetLine = regexp.MustCompile(`^\s*inet ([0-9./]+)`) + inet6Line = regexp.MustCompile(`^\s*inet6 ([0-9a-fA-Z:/]+)`) +) + +// ParseDevices parses the output from `ip addr show` into a map from device +// name to information about the device. +func ParseDevices(cmdOutput string) (map[string]DeviceInfo, error) { + var currentDevice string + var currentInfo DeviceInfo + deviceInfos := make(map[string]DeviceInfo) + for _, line := range strings.Split(cmdOutput, "\n") { + if m := deviceLine.FindStringSubmatch(line); m != nil { + if currentDevice != "" { + deviceInfos[currentDevice] = currentInfo + } + currentInfo = DeviceInfo{} + currentDevice = m[1] + } else if m := linkLine.FindStringSubmatch(line); m != nil { + mac, err := net.ParseMAC(m[1]) + if err != nil { + return nil, err + } + currentInfo.MAC = mac + } else if m := inetLine.FindStringSubmatch(line); m != nil { + ipv4Addr, ipv4Net, err := net.ParseCIDR(m[1]) + if err != nil { + return nil, err + } + currentInfo.IPv4Addr = ipv4Addr + currentInfo.IPv4Net = ipv4Net + } else if m := inet6Line.FindStringSubmatch(line); m != nil { + ipv6Addr, ipv6Net, err := net.ParseCIDR(m[1]) + if err != nil { + return nil, err + } + currentInfo.IPv6Addr = ipv6Addr + currentInfo.IPv6Net = ipv6Net + } + } + if currentDevice != "" { + deviceInfos[currentDevice] = currentInfo + } + return deviceInfos, nil +} + +// MACToIP converts the MAC address to an IPv6 link local address as described +// in RFC 4291 page 20: https://tools.ietf.org/html/rfc4291#page-20 +func MACToIP(mac net.HardwareAddr) net.IP { + addr := make([]byte, header.IPv6AddressSize) + addr[0] = 0xfe + addr[1] = 0x80 + header.EthernetAdddressToModifiedEUI64IntoBuf(tcpip.LinkAddress(mac), addr[8:]) + return net.IP(addr) +} + +// FindDeviceByIP finds a DeviceInfo and device name from an IP address in the +// output of ParseDevices. +func FindDeviceByIP(ip net.IP, devices map[string]DeviceInfo) (string, DeviceInfo, error) { + for dev, info := range devices { + if info.IPv4Addr.Equal(ip) { + return dev, info, nil + } + } + return "", DeviceInfo{}, fmt.Errorf("can't find %s on any interface", ip) +} diff --git a/test/packetimpact/proto/BUILD b/test/packetimpact/proto/BUILD new file mode 100644 index 000000000..4a4370f42 --- /dev/null +++ b/test/packetimpact/proto/BUILD @@ -0,0 +1,12 @@ +load("//tools:defs.bzl", "proto_library") + +package( + default_visibility = ["//test/packetimpact:__subpackages__"], + licenses = ["notice"], +) + +proto_library( + name = "posix_server", + srcs = ["posix_server.proto"], + has_services = 1, +) diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto new file mode 100644 index 000000000..ccd20b10d --- /dev/null +++ b/test/packetimpact/proto/posix_server.proto @@ -0,0 +1,230 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package posix_server; + +message SockaddrIn { + int32 family = 1; + uint32 port = 2; + bytes addr = 3; +} + +message SockaddrIn6 { + uint32 family = 1; + uint32 port = 2; + uint32 flowinfo = 3; + bytes addr = 4; + uint32 scope_id = 5; +} + +message Sockaddr { + oneof sockaddr { + SockaddrIn in = 1; + SockaddrIn6 in6 = 2; + } +} + +message Timeval { + int64 seconds = 1; + int64 microseconds = 2; +} + +message SockOptVal { + oneof val { + bytes bytesval = 1; + int32 intval = 2; + Timeval timeval = 3; + } +} + +// Request and Response pairs for each Posix service RPC call, sorted. + +message AcceptRequest { + int32 sockfd = 1; +} + +message AcceptResponse { + int32 fd = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. + Sockaddr addr = 3; +} + +message BindRequest { + int32 sockfd = 1; + Sockaddr addr = 2; +} + +message BindResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message CloseRequest { + int32 fd = 1; +} + +message CloseResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message ConnectRequest { + int32 sockfd = 1; + Sockaddr addr = 2; +} + +message ConnectResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message FcntlRequest { + int32 fd = 1; + int32 cmd = 2; + int32 arg = 3; +} + +message FcntlResponse { + int32 ret = 1; + int32 errno_ = 2; +} + +message GetSockNameRequest { + int32 sockfd = 1; +} + +message GetSockNameResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. + Sockaddr addr = 3; +} + +message GetSockOptRequest { + int32 sockfd = 1; + int32 level = 2; + int32 optname = 3; + int32 optlen = 4; + enum SockOptType { + UNSPECIFIED = 0; + BYTES = 1; + INT = 2; + TIME = 3; + } + SockOptType type = 5; +} + +message GetSockOptResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. + SockOptVal optval = 3; +} + +message ListenRequest { + int32 sockfd = 1; + int32 backlog = 2; +} + +message ListenResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message SendRequest { + int32 sockfd = 1; + bytes buf = 2; + int32 flags = 3; +} + +message SendResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message SendToRequest { + int32 sockfd = 1; + bytes buf = 2; + int32 flags = 3; + Sockaddr dest_addr = 4; +} + +message SendToResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message SetSockOptRequest { + int32 sockfd = 1; + int32 level = 2; + int32 optname = 3; + SockOptVal optval = 4; +} + +message SetSockOptResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message SocketRequest { + int32 domain = 1; + int32 type = 2; + int32 protocol = 3; +} + +message SocketResponse { + int32 fd = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message RecvRequest { + int32 sockfd = 1; + int32 len = 2; + int32 flags = 3; +} + +message RecvResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. + bytes buf = 3; +} + +service Posix { + // Call accept() on the DUT. + rpc Accept(AcceptRequest) returns (AcceptResponse); + // Call bind() on the DUT. + rpc Bind(BindRequest) returns (BindResponse); + // Call close() on the DUT. + rpc Close(CloseRequest) returns (CloseResponse); + // Call connect() on the DUT. + rpc Connect(ConnectRequest) returns (ConnectResponse); + // Call fcntl() on the DUT. + rpc Fcntl(FcntlRequest) returns (FcntlResponse); + // Call getsockname() on the DUT. + rpc GetSockName(GetSockNameRequest) returns (GetSockNameResponse); + // Call getsockopt() on the DUT. + rpc GetSockOpt(GetSockOptRequest) returns (GetSockOptResponse); + // Call listen() on the DUT. + rpc Listen(ListenRequest) returns (ListenResponse); + // Call send() on the DUT. + rpc Send(SendRequest) returns (SendResponse); + // Call sendto() on the DUT. + rpc SendTo(SendToRequest) returns (SendToResponse); + // Call setsockopt() on the DUT. + rpc SetSockOpt(SetSockOptRequest) returns (SetSockOptResponse); + // Call socket() on the DUT. + rpc Socket(SocketRequest) returns (SocketResponse); + // Call recv() on the DUT. + rpc Recv(RecvRequest) returns (RecvResponse); +} diff --git a/test/packetimpact/runner/BUILD b/test/packetimpact/runner/BUILD new file mode 100644 index 000000000..bad4f0183 --- /dev/null +++ b/test/packetimpact/runner/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_test") + +package( + default_visibility = ["//test/packetimpact:__subpackages__"], + licenses = ["notice"], +) + +go_test( + name = "packetimpact_test", + srcs = ["packetimpact_test.go"], + tags = [ + # Not intended to be run directly. + "local", + "manual", + ], + deps = [ + "//pkg/test/dockerutil", + "//test/packetimpact/netdevs", + "@com_github_docker_docker//api/types/mount:go_default_library", + ], +) diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl new file mode 100644 index 000000000..77cdfea12 --- /dev/null +++ b/test/packetimpact/runner/defs.bzl @@ -0,0 +1,136 @@ +"""Defines rules for packetimpact test targets.""" + +load("//tools:defs.bzl", "go_test") + +def _packetimpact_test_impl(ctx): + test_runner = ctx.executable._test_runner + bench = ctx.actions.declare_file("%s-bench" % ctx.label.name) + bench_content = "\n".join([ + "#!/bin/bash", + # This test will run part in a distinct user namespace. This can cause + # permission problems, because all runfiles may not be owned by the + # current user, and no other users will be mapped in that namespace. + # Make sure that everything is readable here. + "find . -type f -or -type d -exec chmod a+rx {} \\;", + "%s %s --testbench_binary %s $@\n" % ( + test_runner.short_path, + " ".join(ctx.attr.flags), + ctx.files.testbench_binary[0].short_path, + ), + ]) + ctx.actions.write(bench, bench_content, is_executable = True) + + transitive_files = [] + if hasattr(ctx.attr._test_runner, "data_runfiles"): + transitive_files.append(ctx.attr._test_runner.data_runfiles.files) + runfiles = ctx.runfiles( + files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server_binary, + transitive_files = depset(transitive = transitive_files), + collect_default = True, + collect_data = True, + ) + return [DefaultInfo(executable = bench, runfiles = runfiles)] + +_packetimpact_test = rule( + attrs = { + "_test_runner": attr.label( + executable = True, + cfg = "target", + default = ":packetimpact_test", + ), + "_posix_server_binary": attr.label( + cfg = "target", + default = "//test/packetimpact/dut:posix_server", + ), + "testbench_binary": attr.label( + cfg = "target", + mandatory = True, + ), + "flags": attr.string_list( + mandatory = False, + default = [], + ), + }, + test = True, + implementation = _packetimpact_test_impl, +) + +PACKETIMPACT_TAGS = ["local", "manual"] + +def packetimpact_linux_test( + name, + testbench_binary, + expect_failure = False, + **kwargs): + """Add a packetimpact test on linux. + + Args: + name: name of the test + testbench_binary: the testbench binary + expect_failure: the test must fail + **kwargs: all the other args, forwarded to _packetimpact_test + """ + expect_failure_flag = ["--expect_failure"] if expect_failure else [] + _packetimpact_test( + name = name + "_linux_test", + testbench_binary = testbench_binary, + flags = ["--dut_platform", "linux"] + expect_failure_flag, + tags = PACKETIMPACT_TAGS + ["packetimpact"], + **kwargs + ) + +def packetimpact_netstack_test( + name, + testbench_binary, + expect_failure = False, + **kwargs): + """Add a packetimpact test on netstack. + + Args: + name: name of the test + testbench_binary: the testbench binary + expect_failure: the test must fail + **kwargs: all the other args, forwarded to _packetimpact_test + """ + expect_failure_flag = [] + if expect_failure: + expect_failure_flag = ["--expect_failure"] + _packetimpact_test( + name = name + "_netstack_test", + testbench_binary = testbench_binary, + # This is the default runtime unless + # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value. + flags = ["--dut_platform", "netstack", "--runtime=runsc-d"] + expect_failure_flag, + tags = PACKETIMPACT_TAGS + ["packetimpact"], + **kwargs + ) + +def packetimpact_go_test(name, size = "small", pure = True, expect_linux_failure = False, expect_netstack_failure = False, **kwargs): + """Add packetimpact tests written in go. + + Args: + name: name of the test + size: size of the test + pure: make a static go binary + expect_linux_failure: the test must fail for Linux + expect_netstack_failure: the test must fail for Netstack + **kwargs: all the other args, forwarded to go_test + """ + testbench_binary = name + "_test" + go_test( + name = testbench_binary, + size = size, + pure = pure, + tags = PACKETIMPACT_TAGS, + **kwargs + ) + packetimpact_linux_test( + name = name, + expect_failure = expect_linux_failure, + testbench_binary = testbench_binary, + ) + packetimpact_netstack_test( + name = name, + expect_failure = expect_netstack_failure, + testbench_binary = testbench_binary, + ) diff --git a/test/packetimpact/runner/packetimpact_test.go b/test/packetimpact/runner/packetimpact_test.go new file mode 100644 index 000000000..9290d5112 --- /dev/null +++ b/test/packetimpact/runner/packetimpact_test.go @@ -0,0 +1,370 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The runner starts docker containers and networking for a packetimpact test. +package packetimpact_test + +import ( + "context" + "flag" + "fmt" + "io/ioutil" + "log" + "math/rand" + "net" + "os" + "os/exec" + "path" + "strings" + "testing" + "time" + + "github.com/docker/docker/api/types/mount" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/packetimpact/netdevs" +) + +// stringList implements flag.Value. +type stringList []string + +// String implements flag.Value.String. +func (l *stringList) String() string { + return strings.Join(*l, ",") +} + +// Set implements flag.Value.Set. +func (l *stringList) Set(value string) error { + *l = append(*l, value) + return nil +} + +var ( + dutPlatform = flag.String("dut_platform", "", "either \"linux\" or \"netstack\"") + testbenchBinary = flag.String("testbench_binary", "", "path to the testbench binary") + tshark = flag.Bool("tshark", false, "use more verbose tshark in logs instead of tcpdump") + extraTestArgs = stringList{} + expectFailure = flag.Bool("expect_failure", false, "expect that the test will fail when run") + + dutAddr = net.IPv4(0, 0, 0, 10) + testbenchAddr = net.IPv4(0, 0, 0, 20) +) + +const ctrlPort = "40000" + +// logger implements testutil.Logger. +// +// Labels logs based on their source and formats multi-line logs. +type logger string + +// Name implements testutil.Logger.Name. +func (l logger) Name() string { + return string(l) +} + +// Logf implements testutil.Logger.Logf. +func (l logger) Logf(format string, args ...interface{}) { + lines := strings.Split(fmt.Sprintf(format, args...), "\n") + log.Printf("%s: %s", l, lines[0]) + for _, line := range lines[1:] { + log.Printf("%*s %s", len(l), "", line) + } +} + +func TestOne(t *testing.T) { + flag.Var(&extraTestArgs, "extra_test_arg", "extra arguments to pass to the testbench") + flag.Parse() + if *dutPlatform != "linux" && *dutPlatform != "netstack" { + t.Fatal("--dut_platform should be either linux or netstack") + } + if *testbenchBinary == "" { + t.Fatal("--testbench_binary is missing") + } + if *dutPlatform == "netstack" { + if _, err := dockerutil.RuntimePath(); err != nil { + t.Fatal("--runtime is missing or invalid with --dut_platform=netstack:", err) + } + } + dockerutil.EnsureSupportedDockerVersion() + ctx := context.Background() + + // Create the networks needed for the test. One control network is needed for + // the gRPC control packets and one test network on which to transmit the test + // packets. + ctrlNet := dockerutil.NewNetwork(ctx, logger("ctrlNet")) + testNet := dockerutil.NewNetwork(ctx, logger("testNet")) + for _, dn := range []*dockerutil.Network{ctrlNet, testNet} { + for { + if err := createDockerNetwork(ctx, dn); err != nil { + t.Log("creating docker network:", err) + const wait = 100 * time.Millisecond + t.Logf("sleeping %s and will try creating docker network again", wait) + // This can fail if another docker network claimed the same IP so we'll + // just try again. + time.Sleep(wait) + continue + } + break + } + defer func(dn *dockerutil.Network) { + if err := dn.Cleanup(ctx); err != nil { + t.Errorf("unable to cleanup container %s: %s", dn.Name, err) + } + }(dn) + // Sanity check. + inspect, err := dn.Inspect(ctx) + if err != nil { + t.Fatalf("failed to inspect network %s: %v", dn.Name, err) + } else if inspect.Name != dn.Name { + t.Fatalf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name) + } + + } + + tmpDir, err := ioutil.TempDir("", "container-output") + if err != nil { + t.Fatal("creating temp dir:", err) + } + defer os.RemoveAll(tmpDir) + + const testOutputDir = "/tmp/testoutput" + + // Create the Docker container for the DUT. + dut := dockerutil.MakeContainer(ctx, logger("dut")) + if *dutPlatform == "linux" { + dut.Runtime = "" + } + + runOpts := dockerutil.RunOpts{ + Image: "packetimpact", + CapAdd: []string{"NET_ADMIN"}, + Mounts: []mount.Mount{mount.Mount{ + Type: mount.TypeBind, + Source: tmpDir, + Target: testOutputDir, + ReadOnly: false, + }}, + } + + const containerPosixServerBinary = "/packetimpact/posix_server" + dut.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/dut/posix_server") + + conf, hostconf, _ := dut.ConfigsFrom(runOpts, containerPosixServerBinary, "--ip=0.0.0.0", "--port="+ctrlPort) + hostconf.AutoRemove = true + hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"} + + if err := dut.CreateFrom(ctx, conf, hostconf, nil); err != nil { + t.Fatalf("unable to create container %s: %v", dut.Name, err) + } + + defer dut.CleanUp(ctx) + + // Add ctrlNet as eth1 and testNet as eth2. + const testNetDev = "eth2" + if err := addNetworks(ctx, dut, dutAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil { + t.Fatal(err) + } + + if err := dut.Start(ctx); err != nil { + t.Fatalf("unable to start container %s: %s", dut.Name, err) + } + + if _, err := dut.WaitForOutput(ctx, "Server listening.*\n", 60*time.Second); err != nil { + t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.Name, err) + } + + dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet)) + if err != nil { + t.Fatal(err) + } + + remoteMAC := dutDeviceInfo.MAC + remoteIPv6 := dutDeviceInfo.IPv6Addr + // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if + // needed. + if remoteIPv6 == nil { + if _, err := dut.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil { + t.Fatalf("unable to ip addr add on container %s: %s", dut.Name, err) + } + // Now try again, to make sure that it worked. + _, dutDeviceInfo, err = deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet)) + if err != nil { + t.Fatal(err) + } + remoteIPv6 = dutDeviceInfo.IPv6Addr + if remoteIPv6 == nil { + t.Fatal("unable to set IPv6 address on container", dut.Name) + } + } + + // Create the Docker container for the testbench. + testbench := dockerutil.MakeContainer(ctx, logger("testbench")) + testbench.Runtime = "" // The testbench always runs on Linux. + + tbb := path.Base(*testbenchBinary) + containerTestbenchBinary := "/packetimpact/" + tbb + runOpts = dockerutil.RunOpts{ + Image: "packetimpact", + CapAdd: []string{"NET_ADMIN"}, + Mounts: []mount.Mount{mount.Mount{ + Type: mount.TypeBind, + Source: tmpDir, + Target: testOutputDir, + ReadOnly: false, + }}, + } + testbench.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/tests/"+tbb) + + // Run tcpdump in the test bench unbuffered, without DNS resolution, just on + // the interface with the test packets. + snifferArgs := []string{ + "tcpdump", + "-S", "-vvv", "-U", "-n", + "-i", testNetDev, + "-w", testOutputDir + "/dump.pcap", + } + snifferRegex := "tcpdump: listening.*\n" + if *tshark { + // Run tshark in the test bench unbuffered, without DNS resolution, just on + // the interface with the test packets. + snifferArgs = []string{ + "tshark", "-V", "-l", "-n", "-i", testNetDev, + "-o", "tcp.check_checksum:TRUE", + "-o", "udp.check_checksum:TRUE", + } + snifferRegex = "Capturing on.*\n" + } + + defer func() { + if err := exec.Command("/bin/cp", "-r", tmpDir, os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")).Run(); err != nil { + t.Error("unable to copy container output files:", err) + } + }() + + conf, hostconf, _ = testbench.ConfigsFrom(runOpts, snifferArgs...) + hostconf.AutoRemove = true + hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"} + + if err := testbench.CreateFrom(ctx, conf, hostconf, nil); err != nil { + t.Fatalf("unable to create container %s: %s", testbench.Name, err) + } + defer testbench.CleanUp(ctx) + + // Add ctrlNet as eth1 and testNet as eth2. + if err := addNetworks(ctx, testbench, testbenchAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil { + t.Fatal(err) + } + + if err := testbench.Start(ctx); err != nil { + t.Fatalf("unable to start container %s: %s", testbench.Name, err) + } + + // Kill so that it will flush output. + defer func() { + time.Sleep(1 * time.Second) + testbench.Exec(ctx, dockerutil.ExecOpts{}, "killall", snifferArgs[0]) + }() + + if _, err := testbench.WaitForOutput(ctx, snifferRegex, 60*time.Second); err != nil { + t.Fatalf("sniffer on %s never listened: %s", dut.Name, err) + } + + // Because the Linux kernel receives the SYN-ACK but didn't send the SYN it + // will issue a RST. To prevent this IPtables can be used to filter out all + // incoming packets. The raw socket that packetimpact tests use will still see + // everything. + if _, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, "iptables", "-A", "INPUT", "-i", testNetDev, "-j", "DROP"); err != nil { + t.Fatalf("unable to Exec iptables on container %s: %s", testbench.Name, err) + } + + // FIXME(b/156449515): Some piece of the system has a race. The old + // bash script version had a sleep, so we have one too. The race should + // be fixed and this sleep removed. + time.Sleep(time.Second) + + // Start a packetimpact test on the test bench. The packetimpact test sends + // and receives packets and also sends POSIX socket commands to the + // posix_server to be executed on the DUT. + testArgs := []string{containerTestbenchBinary} + testArgs = append(testArgs, extraTestArgs...) + testArgs = append(testArgs, + "--posix_server_ip", addressInSubnet(dutAddr, *ctrlNet.Subnet).String(), + "--posix_server_port", ctrlPort, + "--remote_ipv4", addressInSubnet(dutAddr, *testNet.Subnet).String(), + "--local_ipv4", addressInSubnet(testbenchAddr, *testNet.Subnet).String(), + "--remote_ipv6", remoteIPv6.String(), + "--remote_mac", remoteMAC.String(), + "--device", testNetDev, + "--dut_type", *dutPlatform, + ) + _, err = testbench.Exec(ctx, dockerutil.ExecOpts{}, testArgs...) + if !*expectFailure && err != nil { + t.Fatal("test failed:", err) + } + if *expectFailure && err == nil { + t.Fatal("test failure expected but the test succeeded, enable the test and mark the corresponding bug as fixed") + } +} + +func addNetworks(ctx context.Context, d *dockerutil.Container, addr net.IP, networks []*dockerutil.Network) error { + for _, dn := range networks { + ip := addressInSubnet(addr, *dn.Subnet) + // Connect to the network with the specified IP address. + if err := dn.Connect(ctx, d, ip.String(), ""); err != nil { + return fmt.Errorf("unable to connect container %s to network %s: %w", d.Name, dn.Name, err) + } + } + return nil +} + +// addressInSubnet combines the subnet provided with the address and returns a +// new address. The return address bits come from the subnet where the mask is 1 +// and from the ip address where the mask is 0. +func addressInSubnet(addr net.IP, subnet net.IPNet) net.IP { + var octets []byte + for i := 0; i < 4; i++ { + octets = append(octets, (subnet.IP.To4()[i]&subnet.Mask[i])+(addr.To4()[i]&(^subnet.Mask[i]))) + } + return net.IP(octets) +} + +// createDockerNetwork makes a randomly-named network that will start with the +// namePrefix. The network will be a random /24 subnet. +func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error { + randSource := rand.NewSource(time.Now().UnixNano()) + r1 := rand.New(randSource) + // Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24. + ip := net.IPv4(byte(r1.Intn(224-192)+192), byte(r1.Intn(256)), byte(r1.Intn(256)), 0) + n.Subnet = &net.IPNet{ + IP: ip, + Mask: ip.DefaultMask(), + } + return n.Create(ctx) +} + +// deviceByIP finds a deviceInfo and device name from an IP address. +func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) { + out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "show") + if err != nil { + return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w", d.Name, err) + } + devs, err := netdevs.ParseDevices(out) + if err != nil { + return "", netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w", d.Name, err) + } + testDevice, deviceInfo, err := netdevs.FindDeviceByIP(ip, devs) + if err != nil { + return "", netdevs.DeviceInfo{}, fmt.Errorf("can't find deviceInfo for container %s: %w", d.Name, err) + } + return testDevice, deviceInfo, nil +} diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD new file mode 100644 index 000000000..d19ec07d4 --- /dev/null +++ b/test/packetimpact/testbench/BUILD @@ -0,0 +1,46 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package( + default_visibility = ["//test/packetimpact:__subpackages__"], + licenses = ["notice"], +) + +go_library( + name = "testbench", + srcs = [ + "connections.go", + "dut.go", + "dut_client.go", + "layers.go", + "rawsockets.go", + "testbench.go", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//pkg/usermem", + "//test/packetimpact/netdevs", + "//test/packetimpact/proto:posix_server_go_proto", + "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", + "@com_github_mohae_deepcopy//:go_default_library", + "@org_golang_google_grpc//:go_default_library", + "@org_golang_google_grpc//keepalive:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + "@org_uber_go_multierr//:go_default_library", + ], +) + +go_test( + name = "testbench_test", + size = "small", + srcs = ["layers_test.go"], + library = ":testbench", + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "@com_github_mohae_deepcopy//:go_default_library", + ], +) diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go new file mode 100644 index 000000000..8b4a4d905 --- /dev/null +++ b/test/packetimpact/testbench/connections.go @@ -0,0 +1,950 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testbench has utilities to send and receive packets and also command +// the DUT to run POSIX functions. +package testbench + +import ( + "fmt" + "math/rand" + "net" + "testing" + "time" + + "github.com/mohae/deepcopy" + "go.uber.org/multierr" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +func portFromSockaddr(sa unix.Sockaddr) (uint16, error) { + switch sa := sa.(type) { + case *unix.SockaddrInet4: + return uint16(sa.Port), nil + case *unix.SockaddrInet6: + return uint16(sa.Port), nil + } + return 0, fmt.Errorf("sockaddr type %T does not contain port", sa) +} + +// pickPort makes a new socket and returns the socket FD and port. The domain should be AF_INET or AF_INET6. The caller must close the FD when done with +// the port if there is no error. +func pickPort(domain, typ int) (int, uint16, error) { + fd, err := unix.Socket(domain, typ, 0) + if err != nil { + return -1, 0, err + } + defer func() { + if err != nil { + err = multierr.Append(err, unix.Close(fd)) + } + }() + var sa unix.Sockaddr + switch domain { + case unix.AF_INET: + var sa4 unix.SockaddrInet4 + copy(sa4.Addr[:], net.ParseIP(LocalIPv4).To4()) + sa = &sa4 + case unix.AF_INET6: + var sa6 unix.SockaddrInet6 + copy(sa6.Addr[:], net.ParseIP(LocalIPv6).To16()) + sa = &sa6 + default: + return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain) + } + if err = unix.Bind(fd, sa); err != nil { + return -1, 0, err + } + sa, err = unix.Getsockname(fd) + if err != nil { + return -1, 0, err + } + port, err := portFromSockaddr(sa) + if err != nil { + return -1, 0, err + } + return fd, port, nil +} + +// layerState stores the state of a layer of a connection. +type layerState interface { + // outgoing returns an outgoing layer to be sent in a frame. It should not + // update layerState, that is done in layerState.sent. + outgoing() Layer + + // incoming creates an expected Layer for comparing against a received Layer. + // Because the expectation can depend on values in the received Layer, it is + // an input to incoming. For example, the ACK number needs to be checked in a + // TCP packet but only if the ACK flag is set in the received packet. It + // should not update layerState, that is done in layerState.received. The + // caller takes ownership of the returned Layer. + incoming(received Layer) Layer + + // sent updates the layerState based on the Layer that was sent. The input is + // a Layer with all prev and next pointers populated so that the entire frame + // as it was sent is available. + sent(sent Layer) error + + // received updates the layerState based on a Layer that is receieved. The + // input is a Layer with all prev and next pointers populated so that the + // entire frame as it was receieved is available. + received(received Layer) error + + // close frees associated resources held by the LayerState. + close() error +} + +// etherState maintains state about an Ethernet connection. +type etherState struct { + out, in Ether +} + +var _ layerState = (*etherState)(nil) + +// newEtherState creates a new etherState. +func newEtherState(out, in Ether) (*etherState, error) { + lMAC, err := tcpip.ParseMACAddress(LocalMAC) + if err != nil { + return nil, fmt.Errorf("parsing local MAC: %q: %w", LocalMAC, err) + } + + rMAC, err := tcpip.ParseMACAddress(RemoteMAC) + if err != nil { + return nil, fmt.Errorf("parsing remote MAC: %q: %w", RemoteMAC, err) + } + s := etherState{ + out: Ether{SrcAddr: &lMAC, DstAddr: &rMAC}, + in: Ether{SrcAddr: &rMAC, DstAddr: &lMAC}, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +func (s *etherState) outgoing() Layer { + return deepcopy.Copy(&s.out).(Layer) +} + +// incoming implements layerState.incoming. +func (s *etherState) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*etherState) sent(Layer) error { + return nil +} + +func (*etherState) received(Layer) error { + return nil +} + +func (*etherState) close() error { + return nil +} + +// ipv4State maintains state about an IPv4 connection. +type ipv4State struct { + out, in IPv4 +} + +var _ layerState = (*ipv4State)(nil) + +// newIPv4State creates a new ipv4State. +func newIPv4State(out, in IPv4) (*ipv4State, error) { + lIP := tcpip.Address(net.ParseIP(LocalIPv4).To4()) + rIP := tcpip.Address(net.ParseIP(RemoteIPv4).To4()) + s := ipv4State{ + out: IPv4{SrcAddr: &lIP, DstAddr: &rIP}, + in: IPv4{SrcAddr: &rIP, DstAddr: &lIP}, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +func (s *ipv4State) outgoing() Layer { + return deepcopy.Copy(&s.out).(Layer) +} + +// incoming implements layerState.incoming. +func (s *ipv4State) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*ipv4State) sent(Layer) error { + return nil +} + +func (*ipv4State) received(Layer) error { + return nil +} + +func (*ipv4State) close() error { + return nil +} + +// ipv6State maintains state about an IPv6 connection. +type ipv6State struct { + out, in IPv6 +} + +var _ layerState = (*ipv6State)(nil) + +// newIPv6State creates a new ipv6State. +func newIPv6State(out, in IPv6) (*ipv6State, error) { + lIP := tcpip.Address(net.ParseIP(LocalIPv6).To16()) + rIP := tcpip.Address(net.ParseIP(RemoteIPv6).To16()) + s := ipv6State{ + out: IPv6{SrcAddr: &lIP, DstAddr: &rIP}, + in: IPv6{SrcAddr: &rIP, DstAddr: &lIP}, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +// outgoing returns an outgoing layer to be sent in a frame. +func (s *ipv6State) outgoing() Layer { + return deepcopy.Copy(&s.out).(Layer) +} + +func (s *ipv6State) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (s *ipv6State) sent(Layer) error { + // Nothing to do. + return nil +} + +func (s *ipv6State) received(Layer) error { + // Nothing to do. + return nil +} + +// close cleans up any resources held. +func (s *ipv6State) close() error { + return nil +} + +// tcpState maintains state about a TCP connection. +type tcpState struct { + out, in TCP + localSeqNum, remoteSeqNum *seqnum.Value + synAck *TCP + portPickerFD int + finSent bool +} + +var _ layerState = (*tcpState)(nil) + +// SeqNumValue is a helper routine that allocates a new seqnum.Value value to +// store v and returns a pointer to it. +func SeqNumValue(v seqnum.Value) *seqnum.Value { + return &v +} + +// newTCPState creates a new TCPState. +func newTCPState(domain int, out, in TCP) (*tcpState, error) { + portPickerFD, localPort, err := pickPort(domain, unix.SOCK_STREAM) + if err != nil { + return nil, err + } + s := tcpState{ + out: TCP{SrcPort: &localPort}, + in: TCP{DstPort: &localPort}, + localSeqNum: SeqNumValue(seqnum.Value(rand.Uint32())), + portPickerFD: portPickerFD, + finSent: false, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +func (s *tcpState) outgoing() Layer { + newOutgoing := deepcopy.Copy(s.out).(TCP) + if s.localSeqNum != nil { + newOutgoing.SeqNum = Uint32(uint32(*s.localSeqNum)) + } + if s.remoteSeqNum != nil { + newOutgoing.AckNum = Uint32(uint32(*s.remoteSeqNum)) + } + return &newOutgoing +} + +// incoming implements layerState.incoming. +func (s *tcpState) incoming(received Layer) Layer { + tcpReceived, ok := received.(*TCP) + if !ok { + return nil + } + newIn := deepcopy.Copy(s.in).(TCP) + if s.remoteSeqNum != nil { + newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum)) + } + if s.localSeqNum != nil && (*tcpReceived.Flags&header.TCPFlagAck) != 0 { + // The caller didn't specify an AckNum so we'll expect the calculated one, + // but only if the ACK flag is set because the AckNum is not valid in a + // header if ACK is not set. + newIn.AckNum = Uint32(uint32(*s.localSeqNum)) + } + return &newIn +} + +func (s *tcpState) sent(sent Layer) error { + tcp, ok := sent.(*TCP) + if !ok { + return fmt.Errorf("can't update tcpState with %T Layer", sent) + } + if !s.finSent { + // update localSeqNum by the payload only when FIN is not yet sent by us + for current := tcp.next(); current != nil; current = current.next() { + s.localSeqNum.UpdateForward(seqnum.Size(current.length())) + } + } + if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { + s.localSeqNum.UpdateForward(1) + } + if *tcp.Flags&(header.TCPFlagFin) != 0 { + s.finSent = true + } + return nil +} + +func (s *tcpState) received(l Layer) error { + tcp, ok := l.(*TCP) + if !ok { + return fmt.Errorf("can't update tcpState with %T Layer", l) + } + s.remoteSeqNum = SeqNumValue(seqnum.Value(*tcp.SeqNum)) + if *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { + s.remoteSeqNum.UpdateForward(1) + } + for current := tcp.next(); current != nil; current = current.next() { + s.remoteSeqNum.UpdateForward(seqnum.Size(current.length())) + } + return nil +} + +// close frees the port associated with this connection. +func (s *tcpState) close() error { + if err := unix.Close(s.portPickerFD); err != nil { + return err + } + s.portPickerFD = -1 + return nil +} + +// udpState maintains state about a UDP connection. +type udpState struct { + out, in UDP + portPickerFD int +} + +var _ layerState = (*udpState)(nil) + +// newUDPState creates a new udpState. +func newUDPState(domain int, out, in UDP) (*udpState, error) { + portPickerFD, localPort, err := pickPort(domain, unix.SOCK_DGRAM) + if err != nil { + return nil, err + } + s := udpState{ + out: UDP{SrcPort: &localPort}, + in: UDP{DstPort: &localPort}, + portPickerFD: portPickerFD, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +func (s *udpState) outgoing() Layer { + return deepcopy.Copy(&s.out).(Layer) +} + +// incoming implements layerState.incoming. +func (s *udpState) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*udpState) sent(l Layer) error { + return nil +} + +func (*udpState) received(l Layer) error { + return nil +} + +// close frees the port associated with this connection. +func (s *udpState) close() error { + if err := unix.Close(s.portPickerFD); err != nil { + return err + } + s.portPickerFD = -1 + return nil +} + +// Connection holds a collection of layer states for maintaining a connection +// along with sockets for sniffer and injecting packets. +type Connection struct { + layerStates []layerState + injector Injector + sniffer Sniffer + t *testing.T +} + +// Returns the default incoming frame against which to match. If received is +// longer than layerStates then that may still count as a match. The reverse is +// never a match and nil is returned. +func (conn *Connection) incoming(received Layers) Layers { + if len(received) < len(conn.layerStates) { + return nil + } + in := Layers{} + for i, s := range conn.layerStates { + toMatch := s.incoming(received[i]) + if toMatch == nil { + return nil + } + in = append(in, toMatch) + } + return in +} + +func (conn *Connection) match(override, received Layers) bool { + toMatch := conn.incoming(received) + if toMatch == nil { + return false // Not enough layers in gotLayers for matching. + } + if err := toMatch.merge(override); err != nil { + return false // Failing to merge is not matching. + } + return toMatch.match(received) +} + +// Close frees associated resources held by the Connection. +func (conn *Connection) Close() { + errs := multierr.Combine(conn.sniffer.close(), conn.injector.close()) + for _, s := range conn.layerStates { + if err := s.close(); err != nil { + errs = multierr.Append(errs, fmt.Errorf("unable to close %+v: %s", s, err)) + } + } + if errs != nil { + conn.t.Fatalf("unable to close %+v: %s", conn, errs) + } +} + +// CreateFrame builds a frame for the connection with defaults overriden +// from the innermost layer out, and additionalLayers added after it. +// +// Note that overrideLayers can have a length that is less than the number +// of layers in this connection, and in such cases the innermost layers are +// overriden first. As an example, valid values of overrideLayers for a TCP- +// over-IPv4-over-Ethernet connection are: nil, [TCP], [IPv4, TCP], and +// [Ethernet, IPv4, TCP]. +func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...Layer) Layers { + var layersToSend Layers + for i, s := range conn.layerStates { + layer := s.outgoing() + // overrideLayers and conn.layerStates have their tails aligned, so + // to find the index we move backwards by the distance i is to the + // end. + if j := len(overrideLayers) - (len(conn.layerStates) - i); j >= 0 { + if err := layer.merge(overrideLayers[j]); err != nil { + conn.t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err) + } + } + layersToSend = append(layersToSend, layer) + } + layersToSend = append(layersToSend, additionalLayers...) + return layersToSend +} + +// SendFrameStateless sends a frame without updating any of the layer states. +// +// This method is useful for sending out-of-band control messages such as +// ICMP packets, where it would not make sense to update the transport layer's +// state using the ICMP header. +func (conn *Connection) SendFrameStateless(frame Layers) { + outBytes, err := frame.ToBytes() + if err != nil { + conn.t.Fatalf("can't build outgoing packet: %s", err) + } + conn.injector.Send(outBytes) +} + +// SendFrame sends a frame on the wire and updates the state of all layers. +func (conn *Connection) SendFrame(frame Layers) { + outBytes, err := frame.ToBytes() + if err != nil { + conn.t.Fatalf("can't build outgoing packet: %s", err) + } + conn.injector.Send(outBytes) + + // frame might have nil values where the caller wanted to use default values. + // sentFrame will have no nil values in it because it comes from parsing the + // bytes that were actually sent. + sentFrame := parse(parseEther, outBytes) + // Update the state of each layer based on what was sent. + for i, s := range conn.layerStates { + if err := s.sent(sentFrame[i]); err != nil { + conn.t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err) + } + } +} + +// send sends a packet, possibly with layers of this connection overridden and +// additional layers added. +// +// Types defined with Connection as the underlying type should expose +// type-safe versions of this method. +func (conn *Connection) send(overrideLayers Layers, additionalLayers ...Layer) { + conn.SendFrame(conn.CreateFrame(overrideLayers, additionalLayers...)) +} + +// recvFrame gets the next successfully parsed frame (of type Layers) within the +// timeout provided. If no parsable frame arrives before the timeout, it returns +// nil. +func (conn *Connection) recvFrame(timeout time.Duration) Layers { + if timeout <= 0 { + return nil + } + b := conn.sniffer.Recv(timeout) + if b == nil { + return nil + } + return parse(parseEther, b) +} + +// layersError stores the Layers that we got and the Layers that we wanted to +// match. +type layersError struct { + got, want Layers +} + +func (e *layersError) Error() string { + return e.got.diff(e.want) +} + +// Expect expects a frame with the final layerStates layer matching the +// provided Layer within the timeout specified. If it doesn't arrive in time, +// an error is returned. +func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) { + // Make a frame that will ignore all but the final layer. + layers := make([]Layer, len(conn.layerStates)) + layers[len(layers)-1] = layer + + gotFrame, err := conn.ExpectFrame(layers, timeout) + if err != nil { + return nil, err + } + if len(conn.layerStates)-1 < len(gotFrame) { + return gotFrame[len(conn.layerStates)-1], nil + } + conn.t.Fatal("the received frame should be at least as long as the expected layers") + panic("unreachable") +} + +// ExpectFrame expects a frame that matches the provided Layers within the +// timeout specified. If one arrives in time, the Layers is returned without an +// error. If it doesn't arrive in time, it returns nil and error is non-nil. +func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) { + deadline := time.Now().Add(timeout) + var errs error + for { + var gotLayers Layers + if timeout = time.Until(deadline); timeout > 0 { + gotLayers = conn.recvFrame(timeout) + } + if gotLayers == nil { + if errs == nil { + return nil, fmt.Errorf("got no frames matching %v during %s", layers, timeout) + } + return nil, fmt.Errorf("got no frames matching %v during %s: got %w", layers, timeout, errs) + } + if conn.match(layers, gotLayers) { + for i, s := range conn.layerStates { + if err := s.received(gotLayers[i]); err != nil { + conn.t.Fatal(err) + } + } + return gotLayers, nil + } + errs = multierr.Combine(errs, &layersError{got: gotLayers, want: conn.incoming(gotLayers)}) + } +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *Connection) Drain() { + conn.sniffer.Drain() +} + +// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection. +type TCPIPv4 Connection + +// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults. +func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv4State, err := newIPv4State(IPv4{}, IPv4{}) + if err != nil { + t.Fatalf("can't make ipv4State: %s", err) + } + tcpState, err := newTCPState(unix.AF_INET, outgoingTCP, incomingTCP) + if err != nil { + t.Fatalf("can't make tcpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return TCPIPv4{ + layerStates: []layerState{etherState, ipv4State, tcpState}, + injector: injector, + sniffer: sniffer, + t: t, + } +} + +// Connect performs a TCP 3-way handshake. The input Connection should have a +// final TCP Layer. +func (conn *TCPIPv4) Connect() { + conn.t.Helper() + + // Send the SYN. + conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)}) + + // Wait for the SYN-ACK. + synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + if err != nil { + conn.t.Fatalf("didn't get synack during handshake: %s", err) + } + conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck + + // Send an ACK. + conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) +} + +// ConnectWithOptions performs a TCP 3-way handshake with given TCP options. +// The input Connection should have a final TCP Layer. +func (conn *TCPIPv4) ConnectWithOptions(options []byte) { + conn.t.Helper() + + // Send the SYN. + conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn), Options: options}) + + // Wait for the SYN-ACK. + synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + if err != nil { + conn.t.Fatalf("didn't get synack during handshake: %s", err) + } + conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck + + // Send an ACK. + conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) +} + +// ExpectData is a convenient method that expects a Layer and the Layer after +// it. If it doens't arrive in time, it returns nil. +func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = tcp + if payload != nil { + expected = append(expected, payload) + } + return (*Connection)(conn).ExpectFrame(expected, timeout) +} + +// ExpectNextData attempts to receive the next incoming segment for the +// connection and expects that to match the given layers. +// +// It differs from ExpectData() in that here we are only interested in the next +// received segment, while ExpectData() can receive multiple segments for the +// connection until there is a match with given layers or a timeout. +func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + // Receive the first incoming TCP segment for this connection. + got, err := conn.ExpectData(&TCP{}, nil, timeout) + if err != nil { + return nil, err + } + + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = tcp + if payload != nil { + expected = append(expected, payload) + tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum()) - uint32(payload.Length())) + } + if !(*Connection)(conn).match(expected, got) { + return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got) + } + return got, nil +} + +// Send a packet with reasonable defaults. Potentially override the TCP layer in +// the connection with the provided layer and add additionLayers. +func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { + (*Connection)(conn).send(Layers{&tcp}, additionalLayers...) +} + +// Close frees associated resources held by the TCPIPv4 connection. +func (conn *TCPIPv4) Close() { + (*Connection)(conn).Close() +} + +// Expect expects a frame with the TCP layer matching the provided TCP within +// the timeout specified. If it doesn't arrive in time, an error is returned. +func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) { + layer, err := (*Connection)(conn).Expect(&tcp, timeout) + if layer == nil { + return nil, err + } + gotTCP, ok := layer.(*TCP) + if !ok { + conn.t.Fatalf("expected %s to be TCP", layer) + } + return gotTCP, err +} + +func (conn *TCPIPv4) tcpState() *tcpState { + state, ok := conn.layerStates[2].(*tcpState) + if !ok { + conn.t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2]) + } + return state +} + +func (conn *TCPIPv4) ipv4State() *ipv4State { + state, ok := conn.layerStates[1].(*ipv4State) + if !ok { + conn.t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1]) + } + return state +} + +// RemoteSeqNum returns the next expected sequence number from the DUT. +func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value { + return conn.tcpState().remoteSeqNum +} + +// LocalSeqNum returns the next sequence number to send from the testbench. +func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value { + return conn.tcpState().localSeqNum +} + +// SynAck returns the SynAck that was part of the handshake. +func (conn *TCPIPv4) SynAck() *TCP { + return conn.tcpState().synAck +} + +// LocalAddr gets the local socket address of this connection. +func (conn *TCPIPv4) LocalAddr() *unix.SockaddrInet4 { + sa := &unix.SockaddrInet4{Port: int(*conn.tcpState().out.SrcPort)} + copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr) + return sa +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *TCPIPv4) Drain() { + conn.sniffer.Drain() +} + +// IPv6Conn maintains the state for all the layers in a IPv6 connection. +type IPv6Conn Connection + +// NewIPv6Conn creates a new IPv6Conn connection with reasonable defaults. +func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn { + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make EtherState: %s", err) + } + ipv6State, err := newIPv6State(outgoingIPv6, incomingIPv6) + if err != nil { + t.Fatalf("can't make IPv6State: %s", err) + } + + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return IPv6Conn{ + layerStates: []layerState{etherState, ipv6State}, + injector: injector, + sniffer: sniffer, + t: t, + } +} + +// Send sends a frame with ipv6 overriding the IPv6 layer defaults and +// additionalLayers added after it. +func (conn *IPv6Conn) Send(ipv6 IPv6, additionalLayers ...Layer) { + (*Connection)(conn).send(Layers{&ipv6}, additionalLayers...) +} + +// Close to clean up any resources held. +func (conn *IPv6Conn) Close() { + (*Connection)(conn).Close() +} + +// ExpectFrame expects a frame that matches the provided Layers within the +// timeout specified. If it doesn't arrive in time, an error is returned. +func (conn *IPv6Conn) ExpectFrame(frame Layers, timeout time.Duration) (Layers, error) { + return (*Connection)(conn).ExpectFrame(frame, timeout) +} + +// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection. +type UDPIPv4 Connection + +// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults. +func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 { + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv4State, err := newIPv4State(IPv4{}, IPv4{}) + if err != nil { + t.Fatalf("can't make ipv4State: %s", err) + } + udpState, err := newUDPState(unix.AF_INET, outgoingUDP, incomingUDP) + if err != nil { + t.Fatalf("can't make udpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return UDPIPv4{ + layerStates: []layerState{etherState, ipv4State, udpState}, + injector: injector, + sniffer: sniffer, + t: t, + } +} + +func (conn *UDPIPv4) udpState() *udpState { + state, ok := conn.layerStates[2].(*udpState) + if !ok { + conn.t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) + } + return state +} + +func (conn *UDPIPv4) ipv4State() *ipv4State { + state, ok := conn.layerStates[1].(*ipv4State) + if !ok { + conn.t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1]) + } + return state +} + +// LocalAddr gets the local socket address of this connection. +func (conn *UDPIPv4) LocalAddr() *unix.SockaddrInet4 { + sa := &unix.SockaddrInet4{Port: int(*conn.udpState().out.SrcPort)} + copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr) + return sa +} + +// Send sends a packet with reasonable defaults, potentially overriding the UDP +// layer and adding additionLayers. +func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) { + (*Connection)(conn).send(Layers{&udp}, additionalLayers...) +} + +// SendIP sends a packet with reasonable defaults, potentially overriding the +// UDP and IPv4 headers and adding additionLayers. +func (conn *UDPIPv4) SendIP(ip IPv4, udp UDP, additionalLayers ...Layer) { + (*Connection)(conn).send(Layers{&ip, &udp}, additionalLayers...) +} + +// Expect expects a frame with the UDP layer matching the provided UDP within +// the timeout specified. If it doesn't arrive in time, an error is returned. +func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) (*UDP, error) { + conn.t.Helper() + layer, err := (*Connection)(conn).Expect(&udp, timeout) + if layer == nil { + return nil, err + } + gotUDP, ok := layer.(*UDP) + if !ok { + conn.t.Fatalf("expected %s to be UDP", layer) + } + return gotUDP, err +} + +// ExpectData is a convenient method that expects a Layer and the Layer after +// it. If it doens't arrive in time, it returns nil. +func (conn *UDPIPv4) ExpectData(udp UDP, payload Payload, timeout time.Duration) (Layers, error) { + conn.t.Helper() + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = &udp + if payload.length() != 0 { + expected = append(expected, &payload) + } + return (*Connection)(conn).ExpectFrame(expected, timeout) +} + +// Close frees associated resources held by the UDPIPv4 connection. +func (conn *UDPIPv4) Close() { + (*Connection)(conn).Close() +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *UDPIPv4) Drain() { + conn.sniffer.Drain() +} diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go new file mode 100644 index 000000000..2a2afecb5 --- /dev/null +++ b/test/packetimpact/testbench/dut.go @@ -0,0 +1,658 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "context" + "flag" + "net" + "strconv" + "syscall" + "testing" + + pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" + + "golang.org/x/sys/unix" + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" +) + +// DUT communicates with the DUT to force it to make POSIX calls. +type DUT struct { + t *testing.T + conn *grpc.ClientConn + posixServer POSIXClient +} + +// NewDUT creates a new connection with the DUT over gRPC. +func NewDUT(t *testing.T) DUT { + flag.Parse() + if err := genPseudoFlags(); err != nil { + t.Fatal("generating psuedo flags:", err) + } + + posixServerAddress := POSIXServerIP + ":" + strconv.Itoa(POSIXServerPort) + conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: RPCKeepalive})) + if err != nil { + t.Fatalf("failed to grpc.Dial(%s): %s", posixServerAddress, err) + } + posixServer := NewPOSIXClient(conn) + return DUT{ + t: t, + conn: conn, + posixServer: posixServer, + } +} + +// TearDown closes the underlying connection. +func (dut *DUT) TearDown() { + dut.conn.Close() +} + +func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr { + dut.t.Helper() + switch s := sa.(type) { + case *unix.SockaddrInet4: + return &pb.Sockaddr{ + Sockaddr: &pb.Sockaddr_In{ + In: &pb.SockaddrIn{ + Family: unix.AF_INET, + Port: uint32(s.Port), + Addr: s.Addr[:], + }, + }, + } + case *unix.SockaddrInet6: + return &pb.Sockaddr{ + Sockaddr: &pb.Sockaddr_In6{ + In6: &pb.SockaddrIn6{ + Family: unix.AF_INET6, + Port: uint32(s.Port), + Flowinfo: 0, + ScopeId: s.ZoneId, + Addr: s.Addr[:], + }, + }, + } + } + dut.t.Fatalf("can't parse Sockaddr: %+v", sa) + return nil +} + +func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr { + dut.t.Helper() + switch s := sa.Sockaddr.(type) { + case *pb.Sockaddr_In: + ret := unix.SockaddrInet4{ + Port: int(s.In.GetPort()), + } + copy(ret.Addr[:], s.In.GetAddr()) + return &ret + case *pb.Sockaddr_In6: + ret := unix.SockaddrInet6{ + Port: int(s.In6.GetPort()), + ZoneId: s.In6.GetScopeId(), + } + copy(ret.Addr[:], s.In6.GetAddr()) + } + dut.t.Fatalf("can't parse Sockaddr: %+v", sa) + return nil +} + +// CreateBoundSocket makes a new socket on the DUT, with type typ and protocol +// proto, and bound to the IP address addr. Returns the new file descriptor and +// the port that was selected on the DUT. +func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) { + dut.t.Helper() + var fd int32 + if addr.To4() != nil { + fd = dut.Socket(unix.AF_INET, typ, proto) + sa := unix.SockaddrInet4{} + copy(sa.Addr[:], addr.To4()) + dut.Bind(fd, &sa) + } else if addr.To16() != nil { + fd = dut.Socket(unix.AF_INET6, typ, proto) + sa := unix.SockaddrInet6{} + copy(sa.Addr[:], addr.To16()) + dut.Bind(fd, &sa) + } else { + dut.t.Fatalf("unknown ip addr type for remoteIP") + } + sa := dut.GetSockName(fd) + var port int + switch s := sa.(type) { + case *unix.SockaddrInet4: + port = s.Port + case *unix.SockaddrInet6: + port = s.Port + default: + dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa) + } + return fd, uint16(port) +} + +// CreateListener makes a new TCP connection. If it fails, the test ends. +func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) { + fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(RemoteIPv4)) + dut.Listen(fd, backlog) + return fd, remotePort +} + +// All the functions that make gRPC calls to the POSIX service are below, sorted +// alphabetically. + +// Accept calls accept on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// AcceptWithErrno. +func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + fd, sa, err := dut.AcceptWithErrno(ctx, sockfd) + if fd < 0 { + dut.t.Fatalf("failed to accept: %s", err) + } + return fd, sa +} + +// AcceptWithErrno calls accept on the DUT. +func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) { + dut.t.Helper() + req := pb.AcceptRequest{ + Sockfd: sockfd, + } + resp, err := dut.posixServer.Accept(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Accept: %s", err) + } + return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_()) +} + +// Bind calls bind on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is +// needed, use BindWithErrno. +func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.BindWithErrno(ctx, fd, sa) + if ret != 0 { + dut.t.Fatalf("failed to bind socket: %s", err) + } +} + +// BindWithErrno calls bind on the DUT. +func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) { + dut.t.Helper() + req := pb.BindRequest{ + Sockfd: fd, + Addr: dut.sockaddrToProto(sa), + } + resp, err := dut.posixServer.Bind(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Bind: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// Close calls close on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// CloseWithErrno. +func (dut *DUT) Close(fd int32) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.CloseWithErrno(ctx, fd) + if ret != 0 { + dut.t.Fatalf("failed to close: %s", err) + } +} + +// CloseWithErrno calls close on the DUT. +func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) { + dut.t.Helper() + req := pb.CloseRequest{ + Fd: fd, + } + resp, err := dut.posixServer.Close(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Close: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// Connect calls connect on the DUT and causes a fatal test failure if it +// doesn't succeed. If more control over the timeout or error handling is +// needed, use ConnectWithErrno. +func (dut *DUT) Connect(fd int32, sa unix.Sockaddr) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.ConnectWithErrno(ctx, fd, sa) + // Ignore 'operation in progress' error that can be returned when the socket + // is non-blocking. + if err != syscall.Errno(unix.EINPROGRESS) && ret != 0 { + dut.t.Fatalf("failed to connect socket: %s", err) + } +} + +// ConnectWithErrno calls bind on the DUT. +func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) { + dut.t.Helper() + req := pb.ConnectRequest{ + Sockfd: fd, + Addr: dut.sockaddrToProto(sa), + } + resp, err := dut.posixServer.Connect(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Connect: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// Fcntl calls fcntl on the DUT and causes a fatal test failure if it +// doesn't succeed. If more control over the timeout or error handling is +// needed, use FcntlWithErrno. +func (dut *DUT) Fcntl(fd, cmd, arg int32) int32 { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.FcntlWithErrno(ctx, fd, cmd, arg) + if ret == -1 { + dut.t.Fatalf("failed to Fcntl: ret=%d, errno=%s", ret, err) + } + return ret +} + +// FcntlWithErrno calls fcntl on the DUT. +func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32, error) { + dut.t.Helper() + req := pb.FcntlRequest{ + Fd: fd, + Cmd: cmd, + Arg: arg, + } + resp, err := dut.posixServer.Fcntl(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Fcntl: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// GetSockName calls getsockname on the DUT and causes a fatal test failure if +// it doesn't succeed. If more control over the timeout or error handling is +// needed, use GetSockNameWithErrno. +func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, sa, err := dut.GetSockNameWithErrno(ctx, sockfd) + if ret != 0 { + dut.t.Fatalf("failed to getsockname: %s", err) + } + return sa +} + +// GetSockNameWithErrno calls getsockname on the DUT. +func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) { + dut.t.Helper() + req := pb.GetSockNameRequest{ + Sockfd: sockfd, + } + resp, err := dut.posixServer.GetSockName(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Bind: %s", err) + } + return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_()) +} + +func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) { + dut.t.Helper() + req := pb.GetSockOptRequest{ + Sockfd: sockfd, + Level: level, + Optname: optname, + Optlen: optlen, + Type: typ, + } + resp, err := dut.posixServer.GetSockOpt(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call GetSockOpt: %s", err) + } + optval := resp.GetOptval() + if optval == nil { + dut.t.Fatalf("GetSockOpt response does not contain a value") + } + return resp.GetRet(), optval, syscall.Errno(resp.GetErrno_()) +} + +// GetSockOpt calls getsockopt on the DUT and causes a fatal test failure if it +// doesn't succeed. If more control over the timeout or error handling is +// needed, use GetSockOptWithErrno. Because endianess and the width of values +// might differ between the testbench and DUT architectures, prefer to use a +// more specific GetSockOptXxx function. +func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, optval, err := dut.GetSockOptWithErrno(ctx, sockfd, level, optname, optlen) + if ret != 0 { + dut.t.Fatalf("failed to GetSockOpt: %s", err) + } + return optval +} + +// GetSockOptWithErrno calls getsockopt on the DUT. Because endianess and the +// width of values might differ between the testbench and DUT architectures, +// prefer to use a more specific GetSockOptXxxWithErrno function. +func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname, optlen int32) (int32, []byte, error) { + dut.t.Helper() + ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, optlen, pb.GetSockOptRequest_BYTES) + bytesval, ok := optval.Val.(*pb.SockOptVal_Bytesval) + if !ok { + dut.t.Fatalf("GetSockOpt got value type: %T, want bytes", optval) + } + return ret, bytesval.Bytesval, errno +} + +// GetSockOptInt calls getsockopt on the DUT and causes a fatal test failure +// if it doesn't succeed. If more control over the int optval or error handling +// is needed, use GetSockOptIntWithErrno. +func (dut *DUT) GetSockOptInt(sockfd, level, optname int32) int32 { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, intval, err := dut.GetSockOptIntWithErrno(ctx, sockfd, level, optname) + if ret != 0 { + dut.t.Fatalf("failed to GetSockOptInt: %s", err) + } + return intval +} + +// GetSockOptIntWithErrno calls getsockopt with an integer optval. +func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, int32, error) { + dut.t.Helper() + ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, 0, pb.GetSockOptRequest_INT) + intval, ok := optval.Val.(*pb.SockOptVal_Intval) + if !ok { + dut.t.Fatalf("GetSockOpt got value type: %T, want int", optval) + } + return ret, intval.Intval, errno +} + +// GetSockOptTimeval calls getsockopt on the DUT and causes a fatal test failure +// if it doesn't succeed. If more control over the timeout or error handling is +// needed, use GetSockOptTimevalWithErrno. +func (dut *DUT) GetSockOptTimeval(sockfd, level, optname int32) unix.Timeval { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, sockfd, level, optname) + if ret != 0 { + dut.t.Fatalf("failed to GetSockOptTimeval: %s", err) + } + return timeval +} + +// GetSockOptTimevalWithErrno calls getsockopt and returns a timeval. +func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, unix.Timeval, error) { + dut.t.Helper() + ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, 0, pb.GetSockOptRequest_TIME) + tv, ok := optval.Val.(*pb.SockOptVal_Timeval) + if !ok { + dut.t.Fatalf("GetSockOpt got value type: %T, want timeval", optval) + } + timeval := unix.Timeval{ + Sec: tv.Timeval.Seconds, + Usec: tv.Timeval.Microseconds, + } + return ret, timeval, errno +} + +// Listen calls listen on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// ListenWithErrno. +func (dut *DUT) Listen(sockfd, backlog int32) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.ListenWithErrno(ctx, sockfd, backlog) + if ret != 0 { + dut.t.Fatalf("failed to listen: %s", err) + } +} + +// ListenWithErrno calls listen on the DUT. +func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int32, error) { + dut.t.Helper() + req := pb.ListenRequest{ + Sockfd: sockfd, + Backlog: backlog, + } + resp, err := dut.posixServer.Listen(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Listen: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// Send calls send on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// SendWithErrno. +func (dut *DUT) Send(sockfd int32, buf []byte, flags int32) int32 { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.SendWithErrno(ctx, sockfd, buf, flags) + if ret == -1 { + dut.t.Fatalf("failed to send: %s", err) + } + return ret +} + +// SendWithErrno calls send on the DUT. +func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32) (int32, error) { + dut.t.Helper() + req := pb.SendRequest{ + Sockfd: sockfd, + Buf: buf, + Flags: flags, + } + resp, err := dut.posixServer.Send(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Send: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// SendTo calls sendto on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// SendToWithErrno. +func (dut *DUT) SendTo(sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.SendToWithErrno(ctx, sockfd, buf, flags, destAddr) + if ret == -1 { + dut.t.Fatalf("failed to sendto: %s", err) + } + return ret +} + +// SendToWithErrno calls sendto on the DUT. +func (dut *DUT) SendToWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) { + dut.t.Helper() + req := pb.SendToRequest{ + Sockfd: sockfd, + Buf: buf, + Flags: flags, + DestAddr: dut.sockaddrToProto(destAddr), + } + resp, err := dut.posixServer.SendTo(ctx, &req) + if err != nil { + dut.t.Fatalf("faled to call SendTo: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// SetNonBlocking will set O_NONBLOCK flag for fd if nonblocking +// is true, otherwise it will clear the flag. +func (dut *DUT) SetNonBlocking(fd int32, nonblocking bool) { + dut.t.Helper() + flags := dut.Fcntl(fd, unix.F_GETFL, 0) + if nonblocking { + flags |= unix.O_NONBLOCK + } else { + flags &= ^unix.O_NONBLOCK + } + dut.Fcntl(fd, unix.F_SETFL, flags) +} + +func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) { + dut.t.Helper() + req := pb.SetSockOptRequest{ + Sockfd: sockfd, + Level: level, + Optname: optname, + Optval: optval, + } + resp, err := dut.posixServer.SetSockOpt(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call SetSockOpt: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it +// doesn't succeed. If more control over the timeout or error handling is +// needed, use SetSockOptWithErrno. Because endianess and the width of values +// might differ between the testbench and DUT architectures, prefer to use a +// more specific SetSockOptXxx function. +func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval) + if ret != 0 { + dut.t.Fatalf("failed to SetSockOpt: %s", err) + } +} + +// SetSockOptWithErrno calls setsockopt on the DUT. Because endianess and the +// width of values might differ between the testbench and DUT architectures, +// prefer to use a more specific SetSockOptXxxWithErrno function. +func (dut *DUT) SetSockOptWithErrno(ctx context.Context, sockfd, level, optname int32, optval []byte) (int32, error) { + dut.t.Helper() + return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Bytesval{optval}}) +} + +// SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure +// if it doesn't succeed. If more control over the int optval or error handling +// is needed, use SetSockOptIntWithErrno. +func (dut *DUT) SetSockOptInt(sockfd, level, optname, optval int32) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.SetSockOptIntWithErrno(ctx, sockfd, level, optname, optval) + if ret != 0 { + dut.t.Fatalf("failed to SetSockOptInt: %s", err) + } +} + +// SetSockOptIntWithErrno calls setsockopt with an integer optval. +func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname, optval int32) (int32, error) { + dut.t.Helper() + return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Intval{optval}}) +} + +// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure +// if it doesn't succeed. If more control over the timeout or error handling is +// needed, use SetSockOptTimevalWithErrno. +func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.SetSockOptTimevalWithErrno(ctx, sockfd, level, optname, tv) + if ret != 0 { + dut.t.Fatalf("failed to SetSockOptTimeval: %s", err) + } +} + +// SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to +// bytes. +func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) { + dut.t.Helper() + timeval := pb.Timeval{ + Seconds: int64(tv.Sec), + Microseconds: int64(tv.Usec), + } + return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Timeval{&timeval}}) +} + +// Socket calls socket on the DUT and returns the file descriptor. If socket +// fails on the DUT, the test ends. +func (dut *DUT) Socket(domain, typ, proto int32) int32 { + dut.t.Helper() + fd, err := dut.SocketWithErrno(domain, typ, proto) + if fd < 0 { + dut.t.Fatalf("failed to create socket: %s", err) + } + return fd +} + +// SocketWithErrno calls socket on the DUT and returns the fd and errno. +func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) { + dut.t.Helper() + req := pb.SocketRequest{ + Domain: domain, + Type: typ, + Protocol: proto, + } + ctx := context.Background() + resp, err := dut.posixServer.Socket(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Socket: %s", err) + } + return resp.GetFd(), syscall.Errno(resp.GetErrno_()) +} + +// Recv calls recv on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// RecvWithErrno. +func (dut *DUT) Recv(sockfd, len, flags int32) []byte { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags) + if ret == -1 { + dut.t.Fatalf("failed to recv: %s", err) + } + return buf +} + +// RecvWithErrno calls recv on the DUT. +func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (int32, []byte, error) { + dut.t.Helper() + req := pb.RecvRequest{ + Sockfd: sockfd, + Len: len, + Flags: flags, + } + resp, err := dut.posixServer.Recv(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Recv: %s", err) + } + return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_()) +} diff --git a/test/packetimpact/testbench/dut_client.go b/test/packetimpact/testbench/dut_client.go new file mode 100644 index 000000000..d0e68c5da --- /dev/null +++ b/test/packetimpact/testbench/dut_client.go @@ -0,0 +1,28 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "google.golang.org/grpc" + pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" +) + +// PosixClient is a gRPC client for the Posix service. +type POSIXClient pb.PosixClient + +// NewPOSIXClient makes a new gRPC client for the POSIX service. +func NewPOSIXClient(c grpc.ClientConnInterface) POSIXClient { + return pb.NewPosixClient(c) +} diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go new file mode 100644 index 000000000..a8121b0da --- /dev/null +++ b/test/packetimpact/testbench/layers.go @@ -0,0 +1,1384 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "encoding/hex" + "fmt" + "reflect" + "strings" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "go.uber.org/multierr" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// Layer is the interface that all encapsulations must implement. +// +// A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A +// Layer contains all the fields of the encapsulation. Each field is a pointer +// and may be nil. +type Layer interface { + fmt.Stringer + + // ToBytes converts the Layer into bytes. In places where the Layer's field + // isn't nil, the value that is pointed to is used. When the field is nil, a + // reasonable default for the Layer is used. For example, "64" for IPv4 TTL + // and a calculated checksum for TCP or IP. Some layers require information + // from the previous or next layers in order to compute a default, such as + // TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list + // to the layer's neighbors. + ToBytes() ([]byte, error) + + // match checks if the current Layer matches the provided Layer. If either + // Layer has a nil in a given field, that field is considered matching. + // Otherwise, the values pointed to by the fields must match. The LayerBase is + // ignored. + match(Layer) bool + + // length in bytes of the current encapsulation + length() int + + // next gets a pointer to the encapsulated Layer. + next() Layer + + // prev gets a pointer to the Layer encapsulating this one. + Prev() Layer + + // setNext sets the pointer to the encapsulated Layer. + setNext(Layer) + + // setPrev sets the pointer to the Layer encapsulating this one. + setPrev(Layer) + + // merge overrides the values in the interface with the provided values. + merge(Layer) error +} + +// LayerBase is the common elements of all layers. +type LayerBase struct { + nextLayer Layer + prevLayer Layer +} + +func (lb *LayerBase) next() Layer { + return lb.nextLayer +} + +// Prev returns the previous layer. +func (lb *LayerBase) Prev() Layer { + return lb.prevLayer +} + +func (lb *LayerBase) setNext(l Layer) { + lb.nextLayer = l +} + +func (lb *LayerBase) setPrev(l Layer) { + lb.prevLayer = l +} + +// equalLayer compares that two Layer structs match while ignoring field in +// which either input has a nil and also ignoring the LayerBase of the inputs. +func equalLayer(x, y Layer) bool { + if x == nil || y == nil { + return true + } + // opt ignores comparison pairs where either of the inputs is a nil. + opt := cmp.FilterValues(func(x, y interface{}) bool { + for _, l := range []interface{}{x, y} { + v := reflect.ValueOf(l) + if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice) && v.IsNil() { + return true + } + } + return false + }, cmp.Ignore()) + return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{})) +} + +// mergeLayer merges y into x. Any fields for which y has a non-nil value, that +// value overwrite the corresponding fields in x. +func mergeLayer(x, y Layer) error { + if y == nil { + return nil + } + if reflect.TypeOf(x) != reflect.TypeOf(y) { + return fmt.Errorf("can't merge %T into %T", y, x) + } + vx := reflect.ValueOf(x).Elem() + vy := reflect.ValueOf(y).Elem() + t := vy.Type() + for i := 0; i < vy.NumField(); i++ { + t := t.Field(i) + if t.Anonymous { + // Ignore the LayerBase in the Layer struct. + continue + } + v := vy.Field(i) + if v.IsNil() { + continue + } + vx.Field(i).Set(v) + } + return nil +} + +func stringLayer(l Layer) string { + v := reflect.ValueOf(l).Elem() + t := v.Type() + var ret []string + for i := 0; i < v.NumField(); i++ { + t := t.Field(i) + if t.Anonymous { + // Ignore the LayerBase in the Layer struct. + continue + } + v := v.Field(i) + if v.IsNil() { + continue + } + v = reflect.Indirect(v) + if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 { + ret = append(ret, fmt.Sprintf("%s:\n%v", t.Name, hex.Dump(v.Bytes()))) + } else { + ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v)) + } + } + return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " ")) +} + +// Ether can construct and match an ethernet encapsulation. +type Ether struct { + LayerBase + SrcAddr *tcpip.LinkAddress + DstAddr *tcpip.LinkAddress + Type *tcpip.NetworkProtocolNumber +} + +func (l *Ether) String() string { + return stringLayer(l) +} + +// ToBytes implements Layer.ToBytes. +func (l *Ether) ToBytes() ([]byte, error) { + b := make([]byte, header.EthernetMinimumSize) + h := header.Ethernet(b) + fields := &header.EthernetFields{} + if l.SrcAddr != nil { + fields.SrcAddr = *l.SrcAddr + } + if l.DstAddr != nil { + fields.DstAddr = *l.DstAddr + } + if l.Type != nil { + fields.Type = *l.Type + } else { + switch n := l.next().(type) { + case *IPv4: + fields.Type = header.IPv4ProtocolNumber + case *IPv6: + fields.Type = header.IPv6ProtocolNumber + default: + return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n) + } + } + h.Encode(fields) + return h, nil +} + +// LinkAddress is a helper routine that allocates a new tcpip.LinkAddress value +// to store v and returns a pointer to it. +func LinkAddress(v tcpip.LinkAddress) *tcpip.LinkAddress { + return &v +} + +// NetworkProtocolNumber is a helper routine that allocates a new +// tcpip.NetworkProtocolNumber value to store v and returns a pointer to it. +func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocolNumber { + return &v +} + +// layerParser parses the input bytes and returns a Layer along with the next +// layerParser to run. If there is no more parsing to do, the returned +// layerParser is nil. +type layerParser func([]byte) (Layer, layerParser) + +// parse parses bytes starting with the first layerParser and using successive +// layerParsers until all the bytes are parsed. +func parse(parser layerParser, b []byte) Layers { + var layers Layers + for { + var layer Layer + layer, parser = parser(b) + layers = append(layers, layer) + if parser == nil { + break + } + b = b[layer.length():] + } + layers.linkLayers() + return layers +} + +// parseEther parses the bytes assuming that they start with an ethernet header +// and continues parsing further encapsulations. +func parseEther(b []byte) (Layer, layerParser) { + h := header.Ethernet(b) + ether := Ether{ + SrcAddr: LinkAddress(h.SourceAddress()), + DstAddr: LinkAddress(h.DestinationAddress()), + Type: NetworkProtocolNumber(h.Type()), + } + var nextParser layerParser + switch h.Type() { + case header.IPv4ProtocolNumber: + nextParser = parseIPv4 + case header.IPv6ProtocolNumber: + nextParser = parseIPv6 + default: + // Assume that the rest is a payload. + nextParser = parsePayload + } + return ðer, nextParser +} + +func (l *Ether) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *Ether) length() int { + return header.EthernetMinimumSize +} + +// merge implements Layer.merge. +func (l *Ether) merge(other Layer) error { + return mergeLayer(l, other) +} + +// IPv4 can construct and match an IPv4 encapsulation. +type IPv4 struct { + LayerBase + IHL *uint8 + TOS *uint8 + TotalLength *uint16 + ID *uint16 + Flags *uint8 + FragmentOffset *uint16 + TTL *uint8 + Protocol *uint8 + Checksum *uint16 + SrcAddr *tcpip.Address + DstAddr *tcpip.Address +} + +func (l *IPv4) String() string { + return stringLayer(l) +} + +// ToBytes implements Layer.ToBytes. +func (l *IPv4) ToBytes() ([]byte, error) { + b := make([]byte, header.IPv4MinimumSize) + h := header.IPv4(b) + fields := &header.IPv4Fields{ + IHL: 20, + TOS: 0, + TotalLength: 0, + ID: 0, + Flags: 0, + FragmentOffset: 0, + TTL: 64, + Protocol: 0, + Checksum: 0, + SrcAddr: tcpip.Address(""), + DstAddr: tcpip.Address(""), + } + if l.TOS != nil { + fields.TOS = *l.TOS + } + if l.TotalLength != nil { + fields.TotalLength = *l.TotalLength + } else { + fields.TotalLength = uint16(l.length()) + current := l.next() + for current != nil { + fields.TotalLength += uint16(current.length()) + current = current.next() + } + } + if l.ID != nil { + fields.ID = *l.ID + } + if l.Flags != nil { + fields.Flags = *l.Flags + } + if l.FragmentOffset != nil { + fields.FragmentOffset = *l.FragmentOffset + } + if l.TTL != nil { + fields.TTL = *l.TTL + } + if l.Protocol != nil { + fields.Protocol = *l.Protocol + } else { + switch n := l.next().(type) { + case *TCP: + fields.Protocol = uint8(header.TCPProtocolNumber) + case *UDP: + fields.Protocol = uint8(header.UDPProtocolNumber) + case *ICMPv4: + fields.Protocol = uint8(header.ICMPv4ProtocolNumber) + default: + // TODO(b/150301488): Support more protocols as needed. + return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n) + } + } + if l.SrcAddr != nil { + fields.SrcAddr = *l.SrcAddr + } + if l.DstAddr != nil { + fields.DstAddr = *l.DstAddr + } + if l.Checksum != nil { + fields.Checksum = *l.Checksum + } + h.Encode(fields) + if l.Checksum == nil { + h.SetChecksum(^h.CalculateChecksum()) + } + return h, nil +} + +// Uint16 is a helper routine that allocates a new +// uint16 value to store v and returns a pointer to it. +func Uint16(v uint16) *uint16 { + return &v +} + +// Uint8 is a helper routine that allocates a new +// uint8 value to store v and returns a pointer to it. +func Uint8(v uint8) *uint8 { + return &v +} + +// Address is a helper routine that allocates a new tcpip.Address value to store +// v and returns a pointer to it. +func Address(v tcpip.Address) *tcpip.Address { + return &v +} + +// parseIPv4 parses the bytes assuming that they start with an ipv4 header and +// continues parsing further encapsulations. +func parseIPv4(b []byte) (Layer, layerParser) { + h := header.IPv4(b) + tos, _ := h.TOS() + ipv4 := IPv4{ + IHL: Uint8(h.HeaderLength()), + TOS: &tos, + TotalLength: Uint16(h.TotalLength()), + ID: Uint16(h.ID()), + Flags: Uint8(h.Flags()), + FragmentOffset: Uint16(h.FragmentOffset()), + TTL: Uint8(h.TTL()), + Protocol: Uint8(h.Protocol()), + Checksum: Uint16(h.Checksum()), + SrcAddr: Address(h.SourceAddress()), + DstAddr: Address(h.DestinationAddress()), + } + var nextParser layerParser + switch h.TransportProtocol() { + case header.TCPProtocolNumber: + nextParser = parseTCP + case header.UDPProtocolNumber: + nextParser = parseUDP + case header.ICMPv4ProtocolNumber: + nextParser = parseICMPv4 + default: + // Assume that the rest is a payload. + nextParser = parsePayload + } + return &ipv4, nextParser +} + +func (l *IPv4) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *IPv4) length() int { + if l.IHL == nil { + return header.IPv4MinimumSize + } + return int(*l.IHL) +} + +// merge implements Layer.merge. +func (l *IPv4) merge(other Layer) error { + return mergeLayer(l, other) +} + +// IPv6 can construct and match an IPv6 encapsulation. +type IPv6 struct { + LayerBase + TrafficClass *uint8 + FlowLabel *uint32 + PayloadLength *uint16 + NextHeader *uint8 + HopLimit *uint8 + SrcAddr *tcpip.Address + DstAddr *tcpip.Address +} + +func (l *IPv6) String() string { + return stringLayer(l) +} + +// ToBytes implements Layer.ToBytes. +func (l *IPv6) ToBytes() ([]byte, error) { + b := make([]byte, header.IPv6MinimumSize) + h := header.IPv6(b) + fields := &header.IPv6Fields{ + HopLimit: 64, + } + if l.TrafficClass != nil { + fields.TrafficClass = *l.TrafficClass + } + if l.FlowLabel != nil { + fields.FlowLabel = *l.FlowLabel + } + if l.PayloadLength != nil { + fields.PayloadLength = *l.PayloadLength + } else { + for current := l.next(); current != nil; current = current.next() { + fields.PayloadLength += uint16(current.length()) + } + } + if l.NextHeader != nil { + fields.NextHeader = *l.NextHeader + } else { + switch n := l.next().(type) { + case *TCP: + fields.NextHeader = uint8(header.TCPProtocolNumber) + case *UDP: + fields.NextHeader = uint8(header.UDPProtocolNumber) + case *ICMPv6: + fields.NextHeader = uint8(header.ICMPv6ProtocolNumber) + case *IPv6HopByHopOptionsExtHdr: + fields.NextHeader = uint8(header.IPv6HopByHopOptionsExtHdrIdentifier) + case *IPv6DestinationOptionsExtHdr: + fields.NextHeader = uint8(header.IPv6DestinationOptionsExtHdrIdentifier) + default: + // TODO(b/150301488): Support more protocols as needed. + return nil, fmt.Errorf("ToBytes can't deduce the IPv6 header's next protocol: %#v", n) + } + } + if l.HopLimit != nil { + fields.HopLimit = *l.HopLimit + } + if l.SrcAddr != nil { + fields.SrcAddr = *l.SrcAddr + } + if l.DstAddr != nil { + fields.DstAddr = *l.DstAddr + } + h.Encode(fields) + return h, nil +} + +// nextIPv6PayloadParser finds the corresponding parser for nextHeader. +func nextIPv6PayloadParser(nextHeader uint8) layerParser { + switch tcpip.TransportProtocolNumber(nextHeader) { + case header.TCPProtocolNumber: + return parseTCP + case header.UDPProtocolNumber: + return parseUDP + case header.ICMPv6ProtocolNumber: + return parseICMPv6 + } + switch header.IPv6ExtensionHeaderIdentifier(nextHeader) { + case header.IPv6HopByHopOptionsExtHdrIdentifier: + return parseIPv6HopByHopOptionsExtHdr + case header.IPv6DestinationOptionsExtHdrIdentifier: + return parseIPv6DestinationOptionsExtHdr + } + return parsePayload +} + +// parseIPv6 parses the bytes assuming that they start with an ipv6 header and +// continues parsing further encapsulations. +func parseIPv6(b []byte) (Layer, layerParser) { + h := header.IPv6(b) + tos, flowLabel := h.TOS() + ipv6 := IPv6{ + TrafficClass: &tos, + FlowLabel: &flowLabel, + PayloadLength: Uint16(h.PayloadLength()), + NextHeader: Uint8(h.NextHeader()), + HopLimit: Uint8(h.HopLimit()), + SrcAddr: Address(h.SourceAddress()), + DstAddr: Address(h.DestinationAddress()), + } + nextParser := nextIPv6PayloadParser(h.NextHeader()) + return &ipv6, nextParser +} + +func (l *IPv6) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *IPv6) length() int { + return header.IPv6MinimumSize +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv6) merge(other Layer) error { + return mergeLayer(l, other) +} + +// IPv6HopByHopOptionsExtHdr can construct and match an IPv6HopByHopOptions +// Extension Header. +type IPv6HopByHopOptionsExtHdr struct { + LayerBase + NextHeader *header.IPv6ExtensionHeaderIdentifier + Options []byte +} + +// IPv6DestinationOptionsExtHdr can construct and match an IPv6DestinationOptions +// Extension Header. +type IPv6DestinationOptionsExtHdr struct { + LayerBase + NextHeader *header.IPv6ExtensionHeaderIdentifier + Options []byte +} + +// ipv6OptionsExtHdrToBytes serializes an options extension header into bytes. +func ipv6OptionsExtHdrToBytes(nextHeader *header.IPv6ExtensionHeaderIdentifier, options []byte) []byte { + length := len(options) + 2 + bytes := make([]byte, length) + if nextHeader == nil { + bytes[0] = byte(header.IPv6NoNextHeaderIdentifier) + } else { + bytes[0] = byte(*nextHeader) + } + // ExtHdrLen field is the length of the extension header + // in 8-octet unit, ignoring the first 8 octets. + // https://tools.ietf.org/html/rfc2460#section-4.3 + // https://tools.ietf.org/html/rfc2460#section-4.6 + bytes[1] = uint8((length - 8) / 8) + copy(bytes[2:], options) + return bytes +} + +// IPv6ExtHdrIdent is a helper routine that allocates a new +// header.IPv6ExtensionHeaderIdentifier value to store v and returns a pointer +// to it. +func IPv6ExtHdrIdent(id header.IPv6ExtensionHeaderIdentifier) *header.IPv6ExtensionHeaderIdentifier { + return &id +} + +// ToBytes implements Layer.ToBytes +func (l *IPv6HopByHopOptionsExtHdr) ToBytes() ([]byte, error) { + return ipv6OptionsExtHdrToBytes(l.NextHeader, l.Options), nil +} + +// ToBytes implements Layer.ToBytes +func (l *IPv6DestinationOptionsExtHdr) ToBytes() ([]byte, error) { + return ipv6OptionsExtHdrToBytes(l.NextHeader, l.Options), nil +} + +// parseIPv6ExtHdr parses an IPv6 extension header and returns the NextHeader +// field, the rest of the payload and a parser function for the corresponding +// next extension header. +func parseIPv6ExtHdr(b []byte) (header.IPv6ExtensionHeaderIdentifier, []byte, layerParser) { + nextHeader := b[0] + // For HopByHop and Destination options extension headers, + // This field is the length of the extension header in + // 8-octet units, not including the first 8 octets. + // https://tools.ietf.org/html/rfc2460#section-4.3 + // https://tools.ietf.org/html/rfc2460#section-4.6 + length := b[1]*8 + 8 + data := b[2:length] + nextParser := nextIPv6PayloadParser(nextHeader) + return header.IPv6ExtensionHeaderIdentifier(nextHeader), data, nextParser +} + +// parseIPv6HopByHopOptionsExtHdr parses the bytes assuming that they start +// with an IPv6 HopByHop Options Extension Header. +func parseIPv6HopByHopOptionsExtHdr(b []byte) (Layer, layerParser) { + nextHeader, options, nextParser := parseIPv6ExtHdr(b) + return &IPv6HopByHopOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser +} + +// parseIPv6DestinationOptionsExtHdr parses the bytes assuming that they start +// with an IPv6 Destination Options Extension Header. +func parseIPv6DestinationOptionsExtHdr(b []byte) (Layer, layerParser) { + nextHeader, options, nextParser := parseIPv6ExtHdr(b) + return &IPv6DestinationOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser +} + +func (l *IPv6HopByHopOptionsExtHdr) length() int { + return len(l.Options) + 2 +} + +func (l *IPv6HopByHopOptionsExtHdr) match(other Layer) bool { + return equalLayer(l, other) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv6HopByHopOptionsExtHdr) merge(other Layer) error { + return mergeLayer(l, other) +} + +func (l *IPv6HopByHopOptionsExtHdr) String() string { + return stringLayer(l) +} + +func (l *IPv6DestinationOptionsExtHdr) length() int { + return len(l.Options) + 2 +} + +func (l *IPv6DestinationOptionsExtHdr) match(other Layer) bool { + return equalLayer(l, other) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv6DestinationOptionsExtHdr) merge(other Layer) error { + return mergeLayer(l, other) +} + +func (l *IPv6DestinationOptionsExtHdr) String() string { + return stringLayer(l) +} + +// ICMPv6 can construct and match an ICMPv6 encapsulation. +type ICMPv6 struct { + LayerBase + Type *header.ICMPv6Type + Code *byte + Checksum *uint16 + NDPPayload []byte +} + +func (l *ICMPv6) String() string { + // TODO(eyalsoha): Do something smarter here when *l.Type is ParameterProblem? + // We could parse the contents of the Payload as if it were an IPv6 packet. + return stringLayer(l) +} + +// ToBytes implements Layer.ToBytes. +func (l *ICMPv6) ToBytes() ([]byte, error) { + b := make([]byte, header.ICMPv6HeaderSize+len(l.NDPPayload)) + h := header.ICMPv6(b) + if l.Type != nil { + h.SetType(*l.Type) + } + if l.Code != nil { + h.SetCode(*l.Code) + } + copy(h.NDPPayload(), l.NDPPayload) + if l.Checksum != nil { + h.SetChecksum(*l.Checksum) + } else { + // It is possible that the ICMPv6 header does not follow the IPv6 header + // immediately, there could be one or more extension headers in between. + // We need to search forward to find the IPv6 header. + for prev := l.Prev(); prev != nil; prev = prev.Prev() { + if ipv6, ok := prev.(*IPv6); ok { + h.SetChecksum(header.ICMPv6Checksum(h, *ipv6.SrcAddr, *ipv6.DstAddr, buffer.VectorisedView{})) + break + } + } + } + return h, nil +} + +// ICMPv6Type is a helper routine that allocates a new ICMPv6Type value to store +// v and returns a pointer to it. +func ICMPv6Type(v header.ICMPv6Type) *header.ICMPv6Type { + return &v +} + +// Byte is a helper routine that allocates a new byte value to store +// v and returns a pointer to it. +func Byte(v byte) *byte { + return &v +} + +// parseICMPv6 parses the bytes assuming that they start with an ICMPv6 header. +func parseICMPv6(b []byte) (Layer, layerParser) { + h := header.ICMPv6(b) + icmpv6 := ICMPv6{ + Type: ICMPv6Type(h.Type()), + Code: Byte(h.Code()), + Checksum: Uint16(h.Checksum()), + NDPPayload: h.NDPPayload(), + } + return &icmpv6, nil +} + +func (l *ICMPv6) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *ICMPv6) length() int { + return header.ICMPv6HeaderSize + len(l.NDPPayload) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *ICMPv6) merge(other Layer) error { + return mergeLayer(l, other) +} + +// ICMPv4Type is a helper routine that allocates a new header.ICMPv4Type value +// to store t and returns a pointer to it. +func ICMPv4Type(t header.ICMPv4Type) *header.ICMPv4Type { + return &t +} + +// ICMPv4 can construct and match an ICMPv4 encapsulation. +type ICMPv4 struct { + LayerBase + Type *header.ICMPv4Type + Code *uint8 + Checksum *uint16 +} + +func (l *ICMPv4) String() string { + return stringLayer(l) +} + +// ToBytes implements Layer.ToBytes. +func (l *ICMPv4) ToBytes() ([]byte, error) { + b := make([]byte, header.ICMPv4MinimumSize) + h := header.ICMPv4(b) + if l.Type != nil { + h.SetType(*l.Type) + } + if l.Code != nil { + h.SetCode(byte(*l.Code)) + } + if l.Checksum != nil { + h.SetChecksum(*l.Checksum) + return h, nil + } + payload, err := payload(l) + if err != nil { + return nil, err + } + h.SetChecksum(header.ICMPv4Checksum(h, payload)) + return h, nil +} + +// parseICMPv4 parses the bytes as an ICMPv4 header, returning a Layer and a +// parser for the encapsulated payload. +func parseICMPv4(b []byte) (Layer, layerParser) { + h := header.ICMPv4(b) + icmpv4 := ICMPv4{ + Type: ICMPv4Type(h.Type()), + Code: Uint8(h.Code()), + Checksum: Uint16(h.Checksum()), + } + return &icmpv4, parsePayload +} + +func (l *ICMPv4) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *ICMPv4) length() int { + return header.ICMPv4MinimumSize +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *ICMPv4) merge(other Layer) error { + return mergeLayer(l, other) +} + +// TCP can construct and match a TCP encapsulation. +type TCP struct { + LayerBase + SrcPort *uint16 + DstPort *uint16 + SeqNum *uint32 + AckNum *uint32 + DataOffset *uint8 + Flags *uint8 + WindowSize *uint16 + Checksum *uint16 + UrgentPointer *uint16 + Options []byte +} + +func (l *TCP) String() string { + return stringLayer(l) +} + +// ToBytes implements Layer.ToBytes. +func (l *TCP) ToBytes() ([]byte, error) { + b := make([]byte, l.length()) + h := header.TCP(b) + if l.SrcPort != nil { + h.SetSourcePort(*l.SrcPort) + } + if l.DstPort != nil { + h.SetDestinationPort(*l.DstPort) + } + if l.SeqNum != nil { + h.SetSequenceNumber(*l.SeqNum) + } + if l.AckNum != nil { + h.SetAckNumber(*l.AckNum) + } + if l.DataOffset != nil { + h.SetDataOffset(*l.DataOffset) + } else { + h.SetDataOffset(uint8(l.length())) + } + if l.Flags != nil { + h.SetFlags(*l.Flags) + } + if l.WindowSize != nil { + h.SetWindowSize(*l.WindowSize) + } else { + h.SetWindowSize(32768) + } + if l.UrgentPointer != nil { + h.SetUrgentPoiner(*l.UrgentPointer) + } + copy(b[header.TCPMinimumSize:], l.Options) + header.AddTCPOptionPadding(b[header.TCPMinimumSize:], len(l.Options)) + if l.Checksum != nil { + h.SetChecksum(*l.Checksum) + return h, nil + } + if err := setTCPChecksum(&h, l); err != nil { + return nil, err + } + return h, nil +} + +// totalLength returns the length of the provided layer and all following +// layers. +func totalLength(l Layer) int { + var totalLength int + for ; l != nil; l = l.next() { + totalLength += l.length() + } + return totalLength +} + +// payload returns a buffer.VectorisedView of l's payload. +func payload(l Layer) (buffer.VectorisedView, error) { + var payloadBytes buffer.VectorisedView + for current := l.next(); current != nil; current = current.next() { + payload, err := current.ToBytes() + if err != nil { + return buffer.VectorisedView{}, fmt.Errorf("can't get bytes for next header: %s", payload) + } + payloadBytes.AppendView(payload) + } + return payloadBytes, nil +} + +// layerChecksum calculates the checksum of the Layer header, including the +// peusdeochecksum of the layer before it and all the bytes after it. +func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) { + totalLength := uint16(totalLength(l)) + var xsum uint16 + switch s := l.Prev().(type) { + case *IPv4: + xsum = header.PseudoHeaderChecksum(protoNumber, *s.SrcAddr, *s.DstAddr, totalLength) + default: + // TODO(b/150301488): Support more protocols, like IPv6. + return 0, fmt.Errorf("can't get src and dst addr from previous layer: %#v", s) + } + payloadBytes, err := payload(l) + if err != nil { + return 0, err + } + xsum = header.ChecksumVV(payloadBytes, xsum) + return xsum, nil +} + +// setTCPChecksum calculates the checksum of the TCP header and sets it in h. +func setTCPChecksum(h *header.TCP, tcp *TCP) error { + h.SetChecksum(0) + xsum, err := layerChecksum(tcp, header.TCPProtocolNumber) + if err != nil { + return err + } + h.SetChecksum(^h.CalculateChecksum(xsum)) + return nil +} + +// Uint32 is a helper routine that allocates a new +// uint32 value to store v and returns a pointer to it. +func Uint32(v uint32) *uint32 { + return &v +} + +// parseTCP parses the bytes assuming that they start with a tcp header and +// continues parsing further encapsulations. +func parseTCP(b []byte) (Layer, layerParser) { + h := header.TCP(b) + tcp := TCP{ + SrcPort: Uint16(h.SourcePort()), + DstPort: Uint16(h.DestinationPort()), + SeqNum: Uint32(h.SequenceNumber()), + AckNum: Uint32(h.AckNumber()), + DataOffset: Uint8(h.DataOffset()), + Flags: Uint8(h.Flags()), + WindowSize: Uint16(h.WindowSize()), + Checksum: Uint16(h.Checksum()), + UrgentPointer: Uint16(h.UrgentPointer()), + Options: b[header.TCPMinimumSize:h.DataOffset()], + } + return &tcp, parsePayload +} + +func (l *TCP) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *TCP) length() int { + if l.DataOffset == nil { + // TCP header including the options must end on a 32-bit + // boundary; the user could potentially give us a slice + // whose length is not a multiple of 4 bytes, so we have + // to do the alignment here. + optlen := (len(l.Options) + 3) & ^3 + return header.TCPMinimumSize + optlen + } + return int(*l.DataOffset) +} + +// merge implements Layer.merge. +func (l *TCP) merge(other Layer) error { + return mergeLayer(l, other) +} + +// UDP can construct and match a UDP encapsulation. +type UDP struct { + LayerBase + SrcPort *uint16 + DstPort *uint16 + Length *uint16 + Checksum *uint16 +} + +func (l *UDP) String() string { + return stringLayer(l) +} + +// ToBytes implements Layer.ToBytes. +func (l *UDP) ToBytes() ([]byte, error) { + b := make([]byte, header.UDPMinimumSize) + h := header.UDP(b) + if l.SrcPort != nil { + h.SetSourcePort(*l.SrcPort) + } + if l.DstPort != nil { + h.SetDestinationPort(*l.DstPort) + } + if l.Length != nil { + h.SetLength(*l.Length) + } else { + h.SetLength(uint16(totalLength(l))) + } + if l.Checksum != nil { + h.SetChecksum(*l.Checksum) + return h, nil + } + if err := setUDPChecksum(&h, l); err != nil { + return nil, err + } + return h, nil +} + +// setUDPChecksum calculates the checksum of the UDP header and sets it in h. +func setUDPChecksum(h *header.UDP, udp *UDP) error { + h.SetChecksum(0) + xsum, err := layerChecksum(udp, header.UDPProtocolNumber) + if err != nil { + return err + } + h.SetChecksum(^h.CalculateChecksum(xsum)) + return nil +} + +// parseUDP parses the bytes assuming that they start with a udp header and +// returns the parsed layer and the next parser to use. +func parseUDP(b []byte) (Layer, layerParser) { + h := header.UDP(b) + udp := UDP{ + SrcPort: Uint16(h.SourcePort()), + DstPort: Uint16(h.DestinationPort()), + Length: Uint16(h.Length()), + Checksum: Uint16(h.Checksum()), + } + return &udp, parsePayload +} + +func (l *UDP) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *UDP) length() int { + return header.UDPMinimumSize +} + +// merge implements Layer.merge. +func (l *UDP) merge(other Layer) error { + return mergeLayer(l, other) +} + +// Payload has bytes beyond OSI layer 4. +type Payload struct { + LayerBase + Bytes []byte +} + +func (l *Payload) String() string { + return stringLayer(l) +} + +// parsePayload parses the bytes assuming that they start with a payload and +// continue to the end. There can be no further encapsulations. +func parsePayload(b []byte) (Layer, layerParser) { + payload := Payload{ + Bytes: b, + } + return &payload, nil +} + +// ToBytes implements Layer.ToBytes. +func (l *Payload) ToBytes() ([]byte, error) { + return l.Bytes, nil +} + +// Length returns payload byte length. +func (l *Payload) Length() int { + return l.length() +} + +func (l *Payload) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *Payload) length() int { + return len(l.Bytes) +} + +// merge implements Layer.merge. +func (l *Payload) merge(other Layer) error { + return mergeLayer(l, other) +} + +// Layers is an array of Layer and supports similar functions to Layer. +type Layers []Layer + +// linkLayers sets the linked-list ponters in ls. +func (ls *Layers) linkLayers() { + for i, l := range *ls { + if i > 0 { + l.setPrev((*ls)[i-1]) + } else { + l.setPrev(nil) + } + if i+1 < len(*ls) { + l.setNext((*ls)[i+1]) + } else { + l.setNext(nil) + } + } +} + +// ToBytes converts the Layers into bytes. It creates a linked list of the Layer +// structs and then concatentates the output of ToBytes on each Layer. +func (ls *Layers) ToBytes() ([]byte, error) { + ls.linkLayers() + outBytes := []byte{} + for _, l := range *ls { + layerBytes, err := l.ToBytes() + if err != nil { + return nil, err + } + outBytes = append(outBytes, layerBytes...) + } + return outBytes, nil +} + +func (ls *Layers) match(other Layers) bool { + if len(*ls) > len(other) { + return false + } + for i, l := range *ls { + if !equalLayer(l, other[i]) { + return false + } + } + return true +} + +// layerDiff stores the diffs for each field along with the label for the Layer. +// If rows is nil, that means that there was no diff. +type layerDiff struct { + label string + rows []layerDiffRow +} + +// layerDiffRow stores the fields and corresponding values for two got and want +// layers. If the value was nil then the string stored is the empty string. +type layerDiffRow struct { + field, got, want string +} + +// diffLayer extracts all differing fields between two layers. +func diffLayer(got, want Layer) []layerDiffRow { + vGot := reflect.ValueOf(got).Elem() + vWant := reflect.ValueOf(want).Elem() + if vGot.Type() != vWant.Type() { + return nil + } + t := vGot.Type() + var result []layerDiffRow + for i := 0; i < t.NumField(); i++ { + t := t.Field(i) + if t.Anonymous { + // Ignore the LayerBase in the Layer struct. + continue + } + vGot := vGot.Field(i) + vWant := vWant.Field(i) + gotString := "" + if !vGot.IsNil() { + gotString = fmt.Sprint(reflect.Indirect(vGot)) + } + wantString := "" + if !vWant.IsNil() { + wantString = fmt.Sprint(reflect.Indirect(vWant)) + } + result = append(result, layerDiffRow{t.Name, gotString, wantString}) + } + return result +} + +// layerType returns a concise string describing the type of the Layer, like +// "TCP", or "IPv6". +func layerType(l Layer) string { + return reflect.TypeOf(l).Elem().Name() +} + +// diff compares Layers and returns a representation of the difference. Each +// Layer in the Layers is pairwise compared. If an element in either is nil, it +// is considered a match with the other Layer. If two Layers have differing +// types, they don't match regardless of the contents. If two Layers have the +// same type then the fields in the Layer are pairwise compared. Fields that are +// nil always match. Two non-nil fields only match if they point to equal +// values. diff returns an empty string if and only if *ls and other match. +func (ls *Layers) diff(other Layers) string { + var allDiffs []layerDiff + // Check the cases where one list is longer than the other, where one or both + // elements are nil, where the sides have different types, and where the sides + // have the same type. + for i := 0; i < len(*ls) || i < len(other); i++ { + if i >= len(*ls) { + // Matching ls against other where other is longer than ls. missing + // matches everything so we just include a label without any rows. Having + // no rows is a sign that there was no diff. + allDiffs = append(allDiffs, layerDiff{ + label: "missing matches " + layerType(other[i]), + }) + continue + } + + if i >= len(other) { + // Matching ls against other where ls is longer than other. missing + // matches everything so we just include a label without any rows. Having + // no rows is a sign that there was no diff. + allDiffs = append(allDiffs, layerDiff{ + label: layerType((*ls)[i]) + " matches missing", + }) + continue + } + + if (*ls)[i] == nil && other[i] == nil { + // Matching ls against other where both elements are nil. nil matches + // everything so we just include a label without any rows. Having no rows + // is a sign that there was no diff. + allDiffs = append(allDiffs, layerDiff{ + label: "nil matches nil", + }) + continue + } + + if (*ls)[i] == nil { + // Matching ls against other where the element in ls is nil. nil matches + // everything so we just include a label without any rows. Having no rows + // is a sign that there was no diff. + allDiffs = append(allDiffs, layerDiff{ + label: "nil matches " + layerType(other[i]), + }) + continue + } + + if other[i] == nil { + // Matching ls against other where the element in other is nil. nil + // matches everything so we just include a label without any rows. Having + // no rows is a sign that there was no diff. + allDiffs = append(allDiffs, layerDiff{ + label: layerType((*ls)[i]) + " matches nil", + }) + continue + } + + if reflect.TypeOf((*ls)[i]) == reflect.TypeOf(other[i]) { + // Matching ls against other where both elements have the same type. Match + // each field pairwise and only report a diff if there is a mismatch, + // which is only when both sides are non-nil and have differring values. + diff := diffLayer((*ls)[i], other[i]) + var layerDiffRows []layerDiffRow + for _, d := range diff { + if d.got == "" || d.want == "" || d.got == d.want { + continue + } + layerDiffRows = append(layerDiffRows, layerDiffRow{ + d.field, + d.got, + d.want, + }) + } + if len(layerDiffRows) > 0 { + allDiffs = append(allDiffs, layerDiff{ + label: layerType((*ls)[i]), + rows: layerDiffRows, + }) + } else { + allDiffs = append(allDiffs, layerDiff{ + label: layerType((*ls)[i]) + " matches " + layerType(other[i]), + // Having no rows is a sign that there was no diff. + }) + } + continue + } + // Neither side is nil and the types are different, so we'll display one + // side then the other. + allDiffs = append(allDiffs, layerDiff{ + label: layerType((*ls)[i]) + " doesn't match " + layerType(other[i]), + }) + diff := diffLayer((*ls)[i], (*ls)[i]) + layerDiffRows := []layerDiffRow{} + for _, d := range diff { + if len(d.got) == 0 { + continue + } + layerDiffRows = append(layerDiffRows, layerDiffRow{ + d.field, + d.got, + "", + }) + } + allDiffs = append(allDiffs, layerDiff{ + label: layerType((*ls)[i]), + rows: layerDiffRows, + }) + + layerDiffRows = []layerDiffRow{} + diff = diffLayer(other[i], other[i]) + for _, d := range diff { + if len(d.want) == 0 { + continue + } + layerDiffRows = append(layerDiffRows, layerDiffRow{ + d.field, + "", + d.want, + }) + } + allDiffs = append(allDiffs, layerDiff{ + label: layerType(other[i]), + rows: layerDiffRows, + }) + } + + output := "" + // These are for output formatting. + maxLabelLen, maxFieldLen, maxGotLen, maxWantLen := 0, 0, 0, 0 + foundOne := false + for _, l := range allDiffs { + if len(l.label) > maxLabelLen && len(l.rows) > 0 { + maxLabelLen = len(l.label) + } + if l.rows != nil { + foundOne = true + } + for _, r := range l.rows { + if len(r.field) > maxFieldLen { + maxFieldLen = len(r.field) + } + if l := len(fmt.Sprint(r.got)); l > maxGotLen { + maxGotLen = l + } + if l := len(fmt.Sprint(r.want)); l > maxWantLen { + maxWantLen = l + } + } + } + if !foundOne { + return "" + } + for _, l := range allDiffs { + if len(l.rows) == 0 { + output += "(" + l.label + ")\n" + continue + } + for i, r := range l.rows { + var label string + if i == 0 { + label = l.label + ":" + } + output += fmt.Sprintf( + "%*s %*s %*v %*v\n", + maxLabelLen+1, label, + maxFieldLen+1, r.field+":", + maxGotLen, r.got, + maxWantLen, r.want, + ) + } + } + return output +} + +// merge merges the other Layers into ls. If the other Layers is longer, those +// additional Layer structs are added to ls. The errors from merging are +// collected and returned. +func (ls *Layers) merge(other Layers) error { + var errs error + for i, o := range other { + if i < len(*ls) { + errs = multierr.Combine(errs, (*ls)[i].merge(o)) + } else { + *ls = append(*ls, o) + } + } + return errs +} diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go new file mode 100644 index 000000000..382a983a1 --- /dev/null +++ b/test/packetimpact/testbench/layers_test.go @@ -0,0 +1,618 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "bytes" + "net" + "testing" + + "github.com/mohae/deepcopy" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +func TestLayerMatch(t *testing.T) { + var nilPayload *Payload + noPayload := &Payload{} + emptyPayload := &Payload{Bytes: []byte{}} + fullPayload := &Payload{Bytes: []byte{1, 2, 3}} + emptyTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: emptyPayload}} + fullTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: fullPayload}} + for _, tt := range []struct { + a, b Layer + want bool + }{ + {nilPayload, nilPayload, true}, + {nilPayload, noPayload, true}, + {nilPayload, emptyPayload, true}, + {nilPayload, fullPayload, true}, + {noPayload, noPayload, true}, + {noPayload, emptyPayload, true}, + {noPayload, fullPayload, true}, + {emptyPayload, emptyPayload, true}, + {emptyPayload, fullPayload, false}, + {fullPayload, fullPayload, true}, + {emptyTCP, fullTCP, true}, + } { + if got := tt.a.match(tt.b); got != tt.want { + t.Errorf("%s.match(%s) = %t, want %t", tt.a, tt.b, got, tt.want) + } + if got := tt.b.match(tt.a); got != tt.want { + t.Errorf("%s.match(%s) = %t, want %t", tt.b, tt.a, got, tt.want) + } + } +} + +func TestLayerMergeMismatch(t *testing.T) { + tcp := &TCP{} + otherTCP := &TCP{} + ipv4 := &IPv4{} + ether := &Ether{} + for _, tt := range []struct { + a, b Layer + success bool + }{ + {tcp, tcp, true}, + {tcp, otherTCP, true}, + {tcp, ipv4, false}, + {tcp, ether, false}, + {tcp, nil, true}, + + {otherTCP, otherTCP, true}, + {otherTCP, ipv4, false}, + {otherTCP, ether, false}, + {otherTCP, nil, true}, + + {ipv4, ipv4, true}, + {ipv4, ether, false}, + {ipv4, nil, true}, + + {ether, ether, true}, + {ether, nil, true}, + } { + if err := tt.a.merge(tt.b); (err == nil) != tt.success { + t.Errorf("%s.merge(%s) got %s, wanted the opposite", tt.a, tt.b, err) + } + if tt.b != nil { + if err := tt.b.merge(tt.a); (err == nil) != tt.success { + t.Errorf("%s.merge(%s) got %s, wanted the opposite", tt.b, tt.a, err) + } + } + } +} + +func TestLayerMerge(t *testing.T) { + zero := Uint32(0) + one := Uint32(1) + two := Uint32(2) + empty := []byte{} + foo := []byte("foo") + bar := []byte("bar") + for _, tt := range []struct { + a, b Layer + want Layer + }{ + {&TCP{AckNum: nil}, &TCP{AckNum: nil}, &TCP{AckNum: nil}}, + {&TCP{AckNum: nil}, &TCP{AckNum: zero}, &TCP{AckNum: zero}}, + {&TCP{AckNum: nil}, &TCP{AckNum: one}, &TCP{AckNum: one}}, + {&TCP{AckNum: nil}, &TCP{AckNum: two}, &TCP{AckNum: two}}, + {&TCP{AckNum: nil}, nil, &TCP{AckNum: nil}}, + + {&TCP{AckNum: zero}, &TCP{AckNum: nil}, &TCP{AckNum: zero}}, + {&TCP{AckNum: zero}, &TCP{AckNum: zero}, &TCP{AckNum: zero}}, + {&TCP{AckNum: zero}, &TCP{AckNum: one}, &TCP{AckNum: one}}, + {&TCP{AckNum: zero}, &TCP{AckNum: two}, &TCP{AckNum: two}}, + {&TCP{AckNum: zero}, nil, &TCP{AckNum: zero}}, + + {&TCP{AckNum: one}, &TCP{AckNum: nil}, &TCP{AckNum: one}}, + {&TCP{AckNum: one}, &TCP{AckNum: zero}, &TCP{AckNum: zero}}, + {&TCP{AckNum: one}, &TCP{AckNum: one}, &TCP{AckNum: one}}, + {&TCP{AckNum: one}, &TCP{AckNum: two}, &TCP{AckNum: two}}, + {&TCP{AckNum: one}, nil, &TCP{AckNum: one}}, + + {&TCP{AckNum: two}, &TCP{AckNum: nil}, &TCP{AckNum: two}}, + {&TCP{AckNum: two}, &TCP{AckNum: zero}, &TCP{AckNum: zero}}, + {&TCP{AckNum: two}, &TCP{AckNum: one}, &TCP{AckNum: one}}, + {&TCP{AckNum: two}, &TCP{AckNum: two}, &TCP{AckNum: two}}, + {&TCP{AckNum: two}, nil, &TCP{AckNum: two}}, + + {&Payload{Bytes: nil}, &Payload{Bytes: nil}, &Payload{Bytes: nil}}, + {&Payload{Bytes: nil}, &Payload{Bytes: empty}, &Payload{Bytes: empty}}, + {&Payload{Bytes: nil}, &Payload{Bytes: foo}, &Payload{Bytes: foo}}, + {&Payload{Bytes: nil}, &Payload{Bytes: bar}, &Payload{Bytes: bar}}, + {&Payload{Bytes: nil}, nil, &Payload{Bytes: nil}}, + + {&Payload{Bytes: empty}, &Payload{Bytes: nil}, &Payload{Bytes: empty}}, + {&Payload{Bytes: empty}, &Payload{Bytes: empty}, &Payload{Bytes: empty}}, + {&Payload{Bytes: empty}, &Payload{Bytes: foo}, &Payload{Bytes: foo}}, + {&Payload{Bytes: empty}, &Payload{Bytes: bar}, &Payload{Bytes: bar}}, + {&Payload{Bytes: empty}, nil, &Payload{Bytes: empty}}, + + {&Payload{Bytes: foo}, &Payload{Bytes: nil}, &Payload{Bytes: foo}}, + {&Payload{Bytes: foo}, &Payload{Bytes: empty}, &Payload{Bytes: empty}}, + {&Payload{Bytes: foo}, &Payload{Bytes: foo}, &Payload{Bytes: foo}}, + {&Payload{Bytes: foo}, &Payload{Bytes: bar}, &Payload{Bytes: bar}}, + {&Payload{Bytes: foo}, nil, &Payload{Bytes: foo}}, + + {&Payload{Bytes: bar}, &Payload{Bytes: nil}, &Payload{Bytes: bar}}, + {&Payload{Bytes: bar}, &Payload{Bytes: empty}, &Payload{Bytes: empty}}, + {&Payload{Bytes: bar}, &Payload{Bytes: foo}, &Payload{Bytes: foo}}, + {&Payload{Bytes: bar}, &Payload{Bytes: bar}, &Payload{Bytes: bar}}, + {&Payload{Bytes: bar}, nil, &Payload{Bytes: bar}}, + } { + a := deepcopy.Copy(tt.a).(Layer) + if err := a.merge(tt.b); err != nil { + t.Errorf("%s.merge(%s) = %s, wanted nil", tt.a, tt.b, err) + continue + } + if a.String() != tt.want.String() { + t.Errorf("%s.merge(%s) merge result got %s, want %s", tt.a, tt.b, a, tt.want) + } + } +} + +func TestLayerStringFormat(t *testing.T) { + for _, tt := range []struct { + name string + l Layer + want string + }{ + { + name: "TCP", + l: &TCP{ + SrcPort: Uint16(34785), + DstPort: Uint16(47767), + SeqNum: Uint32(3452155723), + AckNum: Uint32(2596996163), + DataOffset: Uint8(5), + Flags: Uint8(20), + WindowSize: Uint16(64240), + Checksum: Uint16(0x2e2b), + }, + want: "&testbench.TCP{" + + "SrcPort:34785 " + + "DstPort:47767 " + + "SeqNum:3452155723 " + + "AckNum:2596996163 " + + "DataOffset:5 " + + "Flags:20 " + + "WindowSize:64240 " + + "Checksum:11819" + + "}", + }, + { + name: "UDP", + l: &UDP{ + SrcPort: Uint16(34785), + DstPort: Uint16(47767), + Length: Uint16(12), + }, + want: "&testbench.UDP{" + + "SrcPort:34785 " + + "DstPort:47767 " + + "Length:12" + + "}", + }, + { + name: "IPv4", + l: &IPv4{ + IHL: Uint8(5), + TOS: Uint8(0), + TotalLength: Uint16(44), + ID: Uint16(0), + Flags: Uint8(2), + FragmentOffset: Uint16(0), + TTL: Uint8(64), + Protocol: Uint8(6), + Checksum: Uint16(0x2e2b), + SrcAddr: Address(tcpip.Address([]byte{197, 34, 63, 10})), + DstAddr: Address(tcpip.Address([]byte{197, 34, 63, 20})), + }, + want: "&testbench.IPv4{" + + "IHL:5 " + + "TOS:0 " + + "TotalLength:44 " + + "ID:0 " + + "Flags:2 " + + "FragmentOffset:0 " + + "TTL:64 " + + "Protocol:6 " + + "Checksum:11819 " + + "SrcAddr:197.34.63.10 " + + "DstAddr:197.34.63.20" + + "}", + }, + { + name: "Ether", + l: &Ether{ + SrcAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x0a})), + DstAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x14})), + Type: NetworkProtocolNumber(4), + }, + want: "&testbench.Ether{" + + "SrcAddr:02:42:c5:22:3f:0a " + + "DstAddr:02:42:c5:22:3f:14 " + + "Type:4" + + "}", + }, + { + name: "Payload", + l: &Payload{ + Bytes: []byte("Hooray for packetimpact."), + }, + want: "&testbench.Payload{Bytes:\n" + + "00000000 48 6f 6f 72 61 79 20 66 6f 72 20 70 61 63 6b 65 |Hooray for packe|\n" + + "00000010 74 69 6d 70 61 63 74 2e |timpact.|\n" + + "}", + }, + } { + t.Run(tt.name, func(t *testing.T) { + if got := tt.l.String(); got != tt.want { + t.Errorf("%s.String() = %s, want: %s", tt.name, got, tt.want) + } + }) + } +} + +func TestConnectionMatch(t *testing.T) { + conn := Connection{ + layerStates: []layerState{ðerState{}}, + } + protoNum0 := tcpip.NetworkProtocolNumber(0) + protoNum1 := tcpip.NetworkProtocolNumber(1) + for _, tt := range []struct { + description string + override, received Layers + wantMatch bool + }{ + { + description: "shorter override", + override: []Layer{&Ether{}}, + received: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}}, + wantMatch: true, + }, + { + description: "longer override", + override: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}}, + received: []Layer{&Ether{}}, + wantMatch: false, + }, + { + description: "ether layer mismatch", + override: []Layer{&Ether{Type: &protoNum0}}, + received: []Layer{&Ether{Type: &protoNum1}}, + wantMatch: false, + }, + { + description: "both nil", + override: nil, + received: nil, + wantMatch: false, + }, + { + description: "nil override", + override: nil, + received: []Layer{&Ether{}}, + wantMatch: true, + }, + } { + t.Run(tt.description, func(t *testing.T) { + if gotMatch := conn.match(tt.override, tt.received); gotMatch != tt.wantMatch { + t.Fatalf("conn.match(%s, %s) = %t, want %t", tt.override, tt.received, gotMatch, tt.wantMatch) + } + }) + } +} + +func TestLayersDiff(t *testing.T) { + for _, tt := range []struct { + x, y Layers + want string + }{ + { + Layers{&Ether{Type: NetworkProtocolNumber(12)}, &TCP{DataOffset: Uint8(5), SeqNum: Uint32(5)}}, + Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, + "Ether: Type: 12 13\n" + + " TCP: SeqNum: 5 6\n" + + " DataOffset: 5 7\n", + }, + { + Layers{&Ether{Type: NetworkProtocolNumber(12)}, &UDP{SrcPort: Uint16(123)}}, + Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, + "Ether: Type: 12 13\n" + + "(UDP doesn't match TCP)\n" + + " UDP: SrcPort: 123 \n" + + " TCP: SeqNum: 6\n" + + " DataOffset: 7\n", + }, + { + Layers{&UDP{SrcPort: Uint16(123)}}, + Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, + "(UDP doesn't match Ether)\n" + + " UDP: SrcPort: 123 \n" + + "Ether: Type: 13\n" + + "(missing matches TCP)\n", + }, + { + Layers{nil, &UDP{SrcPort: Uint16(123)}}, + Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, + "(nil matches Ether)\n" + + "(UDP doesn't match TCP)\n" + + "UDP: SrcPort: 123 \n" + + "TCP: SeqNum: 6\n" + + " DataOffset: 7\n", + }, + { + Layers{&Ether{Type: NetworkProtocolNumber(13)}, &IPv4{IHL: Uint8(4)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, + Layers{&Ether{Type: NetworkProtocolNumber(13)}, &IPv4{IHL: Uint8(6)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, + "(Ether matches Ether)\n" + + "IPv4: IHL: 4 6\n" + + "(TCP matches TCP)\n", + }, + { + Layers{&Payload{Bytes: []byte("foo")}}, + Layers{&Payload{Bytes: []byte("bar")}}, + "Payload: Bytes: [102 111 111] [98 97 114]\n", + }, + { + Layers{&Payload{Bytes: []byte("")}}, + Layers{&Payload{}}, + "", + }, + { + Layers{&Payload{Bytes: []byte("")}}, + Layers{&Payload{Bytes: []byte("")}}, + "", + }, + { + Layers{&UDP{}}, + Layers{&TCP{}}, + "(UDP doesn't match TCP)\n" + + "(UDP)\n" + + "(TCP)\n", + }, + } { + if got := tt.x.diff(tt.y); got != tt.want { + t.Errorf("%s.diff(%s) = %q, want %q", tt.x, tt.y, got, tt.want) + } + if tt.x.match(tt.y) != (tt.x.diff(tt.y) == "") { + t.Errorf("match and diff of %s and %s disagree", tt.x, tt.y) + } + if tt.y.match(tt.x) != (tt.y.diff(tt.x) == "") { + t.Errorf("match and diff of %s and %s disagree", tt.y, tt.x) + } + } +} + +func TestTCPOptions(t *testing.T) { + for _, tt := range []struct { + description string + wantBytes []byte + wantLayers Layers + }{ + { + description: "without payload", + wantBytes: []byte{ + // IPv4 Header + 0x45, 0x00, 0x00, 0x2c, 0x00, 0x01, 0x00, 0x00, 0x40, 0x06, + 0xf9, 0x77, 0xc0, 0xa8, 0x00, 0x02, 0xc0, 0xa8, 0x00, 0x01, + // TCP Header + 0x30, 0x39, 0xd4, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x60, 0x02, 0x20, 0x00, 0xf5, 0x1c, 0x00, 0x00, + // WindowScale Option + 0x03, 0x03, 0x02, + // NOP Option + 0x00, + }, + wantLayers: []Layer{ + &IPv4{ + IHL: Uint8(20), + TOS: Uint8(0), + TotalLength: Uint16(44), + ID: Uint16(1), + Flags: Uint8(0), + FragmentOffset: Uint16(0), + TTL: Uint8(64), + Protocol: Uint8(uint8(header.TCPProtocolNumber)), + Checksum: Uint16(0xf977), + SrcAddr: Address(tcpip.Address(net.ParseIP("192.168.0.2").To4())), + DstAddr: Address(tcpip.Address(net.ParseIP("192.168.0.1").To4())), + }, + &TCP{ + SrcPort: Uint16(12345), + DstPort: Uint16(54321), + SeqNum: Uint32(0), + AckNum: Uint32(0), + Flags: Uint8(header.TCPFlagSyn), + WindowSize: Uint16(8192), + Checksum: Uint16(0xf51c), + UrgentPointer: Uint16(0), + Options: []byte{3, 3, 2, 0}, + }, + &Payload{Bytes: nil}, + }, + }, + { + description: "with payload", + wantBytes: []byte{ + // IPv4 header + 0x45, 0x00, 0x00, 0x37, 0x00, 0x01, 0x00, 0x00, 0x40, 0x06, + 0xf9, 0x6c, 0xc0, 0xa8, 0x00, 0x02, 0xc0, 0xa8, 0x00, 0x01, + // TCP header + 0x30, 0x39, 0xd4, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x60, 0x02, 0x20, 0x00, 0xe5, 0x21, 0x00, 0x00, + // WindowScale Option + 0x03, 0x03, 0x02, + // NOP Option + 0x00, + // Payload: "Sample Data" + 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61, + }, + wantLayers: []Layer{ + &IPv4{ + IHL: Uint8(20), + TOS: Uint8(0), + TotalLength: Uint16(55), + ID: Uint16(1), + Flags: Uint8(0), + FragmentOffset: Uint16(0), + TTL: Uint8(64), + Protocol: Uint8(uint8(header.TCPProtocolNumber)), + Checksum: Uint16(0xf96c), + SrcAddr: Address(tcpip.Address(net.ParseIP("192.168.0.2").To4())), + DstAddr: Address(tcpip.Address(net.ParseIP("192.168.0.1").To4())), + }, + &TCP{ + SrcPort: Uint16(12345), + DstPort: Uint16(54321), + SeqNum: Uint32(0), + AckNum: Uint32(0), + Flags: Uint8(header.TCPFlagSyn), + WindowSize: Uint16(8192), + Checksum: Uint16(0xe521), + UrgentPointer: Uint16(0), + Options: []byte{3, 3, 2, 0}, + }, + &Payload{Bytes: []byte("Sample Data")}, + }, + }, + } { + t.Run(tt.description, func(t *testing.T) { + layers := parse(parseIPv4, tt.wantBytes) + if !layers.match(tt.wantLayers) { + t.Fatalf("match failed with diff: %s", layers.diff(tt.wantLayers)) + } + gotBytes, err := layers.ToBytes() + if err != nil { + t.Fatalf("ToBytes() failed on %s: %s", &layers, err) + } + if !bytes.Equal(tt.wantBytes, gotBytes) { + t.Fatalf("mismatching bytes, gotBytes: %x, wantBytes: %x", gotBytes, tt.wantBytes) + } + }) + } +} + +func TestIPv6ExtHdrOptions(t *testing.T) { + for _, tt := range []struct { + description string + wantBytes []byte + wantLayers Layers + }{ + { + description: "IPv6/HopByHop", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // HopByHop Options + 0x3b, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6HopByHopOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &Payload{ + Bytes: nil, + }, + }, + }, + { + description: "IPv6/HopByHop/Payload", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // HopByHop Options + 0x3b, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + // Sample Data + 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6HopByHopOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &Payload{ + Bytes: []byte("Sample Data"), + }, + }, + }, + { + description: "IPv6/HopByHop/Destination/ICMPv6", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x18, 0x00, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // HopByHop Options + 0x3c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + // Destination Options + 0x3a, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + // ICMPv6 Param Problem + 0x04, 0x00, 0x5f, 0x98, 0x00, 0x00, 0x00, 0x06, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6HopByHopOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6DestinationOptionsExtHdrIdentifier), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &IPv6DestinationOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber)), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &ICMPv6{ + Type: ICMPv6Type(header.ICMPv6ParamProblem), + Code: Byte(0), + Checksum: Uint16(0x5f98), + NDPPayload: []byte{0x00, 0x00, 0x00, 0x06}, + }, + }, + }, + } { + t.Run(tt.description, func(t *testing.T) { + layers := parse(parseIPv6, tt.wantBytes) + if !layers.match(tt.wantLayers) { + t.Fatalf("match failed with diff: %s", layers.diff(tt.wantLayers)) + } + gotBytes, err := layers.ToBytes() + if err != nil { + t.Fatalf("ToBytes() failed on %s: %s", &layers, err) + } + if !bytes.Equal(tt.wantBytes, gotBytes) { + t.Fatalf("mismatching bytes, gotBytes: %x, wantBytes: %x", gotBytes, tt.wantBytes) + } + }) + } +} diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go new file mode 100644 index 000000000..278229b7e --- /dev/null +++ b/test/packetimpact/testbench/rawsockets.go @@ -0,0 +1,178 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "encoding/binary" + "fmt" + "math" + "net" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Sniffer can sniff raw packets on the wire. +type Sniffer struct { + t *testing.T + fd int +} + +func htons(x uint16) uint16 { + buf := [2]byte{} + binary.BigEndian.PutUint16(buf[:], x) + return usermem.ByteOrder.Uint16(buf[:]) +} + +// NewSniffer creates a Sniffer connected to *device. +func NewSniffer(t *testing.T) (Sniffer, error) { + snifferFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL))) + if err != nil { + return Sniffer{}, err + } + if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, 1); err != nil { + t.Fatalf("can't set sockopt SO_RCVBUFFORCE to 1: %s", err) + } + if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1e7); err != nil { + t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err) + } + return Sniffer{ + t: t, + fd: snifferFd, + }, nil +} + +// maxReadSize should be large enough for the maximum frame size in bytes. If a +// packet too large for the buffer arrives, the test will get a fatal error. +const maxReadSize int = 65536 + +// Recv tries to read one frame until the timeout is up. +func (s *Sniffer) Recv(timeout time.Duration) []byte { + deadline := time.Now().Add(timeout) + for { + timeout = deadline.Sub(time.Now()) + if timeout <= 0 { + return nil + } + whole, frac := math.Modf(timeout.Seconds()) + tv := unix.Timeval{ + Sec: int64(whole), + Usec: int64(frac * float64(time.Microsecond/time.Second)), + } + + if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil { + s.t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err) + } + + buf := make([]byte, maxReadSize) + nread, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC) + if err == unix.EINTR || err == unix.EAGAIN { + // There was a timeout. + continue + } + if err != nil { + s.t.Fatalf("can't read: %s", err) + } + if nread > maxReadSize { + s.t.Fatalf("received a truncated frame of %d bytes", nread) + } + return buf[:nread] + } +} + +// Drain drains the Sniffer's socket receive buffer by receiving until there's +// nothing else to receive. +func (s *Sniffer) Drain() { + s.t.Helper() + flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0) + if err != nil { + s.t.Fatalf("failed to get sniffer socket fd flags: %s", err) + } + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags|unix.O_NONBLOCK); err != nil { + s.t.Fatalf("failed to make sniffer socket non-blocking: %s", err) + } + for { + buf := make([]byte, maxReadSize) + _, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC) + if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK { + break + } + } + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil { + s.t.Fatalf("failed to restore sniffer socket fd flags: %s", err) + } +} + +// close the socket that Sniffer is using. +func (s *Sniffer) close() error { + if err := unix.Close(s.fd); err != nil { + return fmt.Errorf("can't close sniffer socket: %w", err) + } + s.fd = -1 + return nil +} + +// Injector can inject raw frames. +type Injector struct { + t *testing.T + fd int +} + +// NewInjector creates a new injector on *device. +func NewInjector(t *testing.T) (Injector, error) { + ifInfo, err := net.InterfaceByName(Device) + if err != nil { + return Injector{}, err + } + + var haddr [8]byte + copy(haddr[:], ifInfo.HardwareAddr) + sa := unix.SockaddrLinklayer{ + Protocol: unix.ETH_P_IP, + Ifindex: ifInfo.Index, + Halen: uint8(len(ifInfo.HardwareAddr)), + Addr: haddr, + } + + injectFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL))) + if err != nil { + return Injector{}, err + } + if err := unix.Bind(injectFd, &sa); err != nil { + return Injector{}, err + } + return Injector{ + t: t, + fd: injectFd, + }, nil +} + +// Send a raw frame. +func (i *Injector) Send(b []byte) { + if _, err := unix.Write(i.fd, b); err != nil { + i.t.Fatalf("can't write: %s of len %d", err, len(b)) + } +} + +// close the underlying socket. +func (i *Injector) close() error { + if err := unix.Close(i.fd); err != nil { + return fmt.Errorf("can't close sniffer socket: %w", err) + } + i.fd = -1 + return nil +} diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go new file mode 100644 index 000000000..d64f32a5b --- /dev/null +++ b/test/packetimpact/testbench/testbench.go @@ -0,0 +1,106 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "flag" + "fmt" + "math/rand" + "net" + "os/exec" + "testing" + "time" + + "gvisor.dev/gvisor/test/packetimpact/netdevs" +) + +var ( + // DUTType is the type of device under test. + DUTType = "" + // Device is the local device on the test network. + Device = "" + // LocalIPv4 is the local IPv4 address on the test network. + LocalIPv4 = "" + // LocalIPv6 is the local IPv6 address on the test network. + LocalIPv6 = "" + // LocalMAC is the local MAC address on the test network. + LocalMAC = "" + // POSIXServerIP is the POSIX server's IP address on the control network. + POSIXServerIP = "" + // POSIXServerPort is the UDP port the POSIX server is bound to on the + // control network. + POSIXServerPort = 40000 + // RemoteIPv4 is the DUT's IPv4 address on the test network. + RemoteIPv4 = "" + // RemoteIPv6 is the DUT's IPv6 address on the test network. + RemoteIPv6 = "" + // RemoteMAC is the DUT's MAC address on the test network. + RemoteMAC = "" + // RPCKeepalive is the gRPC keepalive. + RPCKeepalive = 10 * time.Second + // RPCTimeout is the gRPC timeout. + RPCTimeout = 100 * time.Millisecond +) + +// RegisterFlags defines flags and associates them with the package-level +// exported variables above. It should be called by tests in their init +// functions. +func RegisterFlags(fs *flag.FlagSet) { + fs.StringVar(&POSIXServerIP, "posix_server_ip", POSIXServerIP, "ip address to listen to for UDP commands") + fs.IntVar(&POSIXServerPort, "posix_server_port", POSIXServerPort, "port to listen to for UDP commands") + fs.DurationVar(&RPCTimeout, "rpc_timeout", RPCTimeout, "gRPC timeout") + fs.DurationVar(&RPCKeepalive, "rpc_keepalive", RPCKeepalive, "gRPC keepalive") + fs.StringVar(&LocalIPv4, "local_ipv4", LocalIPv4, "local IPv4 address for test packets") + fs.StringVar(&RemoteIPv4, "remote_ipv4", RemoteIPv4, "remote IPv4 address for test packets") + fs.StringVar(&RemoteIPv6, "remote_ipv6", RemoteIPv6, "remote IPv6 address for test packets") + fs.StringVar(&RemoteMAC, "remote_mac", RemoteMAC, "remote mac address for test packets") + fs.StringVar(&Device, "device", Device, "local device for test packets") + fs.StringVar(&DUTType, "dut_type", DUTType, "type of device under test") +} + +// genPseudoFlags populates flag-like global config based on real flags. +// +// genPseudoFlags must only be called after flag.Parse. +func genPseudoFlags() error { + out, err := exec.Command("ip", "addr", "show").CombinedOutput() + if err != nil { + return fmt.Errorf("listing devices: %q: %w", string(out), err) + } + devs, err := netdevs.ParseDevices(string(out)) + if err != nil { + return fmt.Errorf("parsing devices: %w", err) + } + + _, deviceInfo, err := netdevs.FindDeviceByIP(net.ParseIP(LocalIPv4), devs) + if err != nil { + return fmt.Errorf("can't find deviceInfo: %w", err) + } + + LocalMAC = deviceInfo.MAC.String() + LocalIPv6 = deviceInfo.IPv6Addr.String() + + return nil +} + +// GenerateRandomPayload generates a random byte slice of the specified length, +// causing a fatal test failure if it is unable to do so. +func GenerateRandomPayload(t *testing.T, n int) []byte { + t.Helper() + buf := make([]byte, n) + if _, err := rand.Read(buf); err != nil { + t.Fatalf("rand.Read(buf) failed: %s", err) + } + return buf +} diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD new file mode 100644 index 000000000..3ecbe83eb --- /dev/null +++ b/test/packetimpact/tests/BUILD @@ -0,0 +1,264 @@ +load("//test/packetimpact/runner:defs.bzl", "packetimpact_go_test") + +package( + default_visibility = ["//test/packetimpact:__subpackages__"], + licenses = ["notice"], +) + +packetimpact_go_test( + name = "fin_wait2_timeout", + srcs = ["fin_wait2_timeout_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "ipv4_id_uniqueness", + srcs = ["ipv4_id_uniqueness_test.go"], + deps = [ + "//pkg/abi/linux", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "udp_recv_multicast", + srcs = ["udp_recv_multicast_test.go"], + # TODO(b/152813495): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "udp_icmp_error_propagation", + srcs = ["udp_icmp_error_propagation_test.go"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_reordering", + srcs = ["tcp_reordering_test.go"], + # TODO(b/139368047): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_window_shrink", + srcs = ["tcp_window_shrink_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_zero_window_probe", + srcs = ["tcp_zero_window_probe_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_zero_window_probe_retransmit", + srcs = ["tcp_zero_window_probe_retransmit_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_zero_window_probe_usertimeout", + srcs = ["tcp_zero_window_probe_usertimeout_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_retransmits", + srcs = ["tcp_retransmits_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_outside_the_window", + srcs = ["tcp_outside_the_window_test.go"], + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_noaccept_close_rst", + srcs = ["tcp_noaccept_close_rst_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_send_window_sizes_piggyback", + srcs = ["tcp_send_window_sizes_piggyback_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_close_wait_ack", + srcs = ["tcp_close_wait_ack_test.go"], + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_paws_mechanism", + srcs = ["tcp_paws_mechanism_test.go"], + # TODO(b/156682000): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_user_timeout", + srcs = ["tcp_user_timeout_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_queue_receive_in_syn_sent", + srcs = ["tcp_queue_receive_in_syn_sent_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_synsent_reset", + srcs = ["tcp_synsent_reset_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_synrcvd_reset", + srcs = ["tcp_synrcvd_reset_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_cork_mss", + srcs = ["tcp_cork_mss_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_handshake_window_size", + srcs = ["tcp_handshake_window_size_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "icmpv6_param_problem", + srcs = ["icmpv6_param_problem_test.go"], + # TODO(b/153485026): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "ipv6_unknown_options_action", + srcs = ["ipv6_unknown_options_action_test.go"], + # TODO(b/159928940): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "udp_send_recv_dgram", + srcs = ["udp_send_recv_dgram_test.go"], + deps = [ + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go new file mode 100644 index 000000000..407565078 --- /dev/null +++ b/test/packetimpact/tests/fin_wait2_timeout_test.go @@ -0,0 +1,75 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fin_wait2_timeout_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestFinWait2Timeout(t *testing.T) { + for _, tt := range []struct { + description string + linger2 bool + }{ + {"WithLinger2", true}, + {"WithoutLinger2", false}, + } { + t.Run(tt.description, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + conn.Connect() + + acceptFd, _ := dut.Accept(listenFd) + if tt.linger2 { + tv := unix.Timeval{Sec: 1, Usec: 0} + dut.SetSockOptTimeval(acceptFd, unix.SOL_TCP, unix.TCP_LINGER2, &tv) + } + dut.Close(acceptFd) + + if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err) + } + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + + time.Sleep(5 * time.Second) + conn.Drain() + + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + if tt.linger2 { + if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected a RST packet within a second but got none: %s", err) + } + } else { + if got, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil { + t.Fatalf("expected no RST packets within ten seconds but got one: %s", got) + } + } + }) + } +} diff --git a/test/packetimpact/tests/icmpv6_param_problem_test.go b/test/packetimpact/tests/icmpv6_param_problem_test.go new file mode 100644 index 000000000..4d1d9a7f5 --- /dev/null +++ b/test/packetimpact/tests/icmpv6_param_problem_test.go @@ -0,0 +1,78 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package icmpv6_param_problem_test + +import ( + "encoding/binary" + "flag" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestICMPv6ParamProblemTest sends a packet with a bad next header. The DUT +// should respond with an ICMPv6 Parameter Problem message. +func TestICMPv6ParamProblemTest(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) + defer conn.Close() + ipv6 := testbench.IPv6{ + // 254 is reserved and used for experimentation and testing. This should + // cause an error. + NextHeader: testbench.Uint8(254), + } + icmpv6 := testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), + NDPPayload: []byte("hello world"), + } + + toSend := (*testbench.Connection)(&conn).CreateFrame(testbench.Layers{&ipv6}, &icmpv6) + (*testbench.Connection)(&conn).SendFrame(toSend) + + // Build the expected ICMPv6 payload, which includes an index to the + // problematic byte and also the problematic packet as described in + // https://tools.ietf.org/html/rfc4443#page-12 . + ipv6Sent := toSend[1:] + expectedPayload, err := ipv6Sent.ToBytes() + if err != nil { + t.Fatalf("can't convert %s to bytes: %s", ipv6Sent, err) + } + + // The problematic field is the NextHeader. + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, header.IPv6NextHeaderOffset) + expectedPayload = append(b, expectedPayload...) + expectedICMPv6 := testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem), + NDPPayload: expectedPayload, + } + + paramProblem := testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv6{}, + &expectedICMPv6, + } + timeout := time.Second + if _, err := conn.ExpectFrame(paramProblem, timeout); err != nil { + t.Errorf("expected %s within %s but got none: %s", paramProblem, timeout, err) + } +} diff --git a/test/packetimpact/tests/ipv4_id_uniqueness_test.go b/test/packetimpact/tests/ipv4_id_uniqueness_test.go new file mode 100644 index 000000000..70f6df5e0 --- /dev/null +++ b/test/packetimpact/tests/ipv4_id_uniqueness_test.go @@ -0,0 +1,122 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4_id_uniqueness_test + +import ( + "context" + "flag" + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func recvTCPSegment(conn *testbench.TCPIPv4, expect *testbench.TCP, expectPayload *testbench.Payload) (uint16, error) { + layers, err := conn.ExpectData(expect, expectPayload, time.Second) + if err != nil { + return 0, fmt.Errorf("failed to receive TCP segment: %s", err) + } + if len(layers) < 2 { + return 0, fmt.Errorf("got packet with layers: %v, expected to have at least 2 layers (link and network)", layers) + } + ipv4, ok := layers[1].(*testbench.IPv4) + if !ok { + return 0, fmt.Errorf("got network layer: %T, expected: *IPv4", layers[1]) + } + if *ipv4.Flags&header.IPv4FlagDontFragment != 0 { + return 0, fmt.Errorf("got IPv4 DF=1, expected DF=0") + } + return *ipv4.ID, nil +} + +// RFC 6864 section 4.2 states: "The IPv4 ID of non-atomic datagrams MUST NOT +// be reused when sending a copy of an earlier non-atomic datagram." +// +// This test creates a TCP connection, uses the IP_MTU_DISCOVER socket option +// to force the DF bit to be 0, and checks that a retransmitted segment has a +// different IPv4 Identification value than the original segment. +func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) { + for _, tc := range []struct { + name string + payload []byte + }{ + {"SmallPayload", []byte("sample data")}, + // 512 bytes is chosen because sending more than this in a single segment + // causes the retransmission to send less than the original amount. + {"512BytePayload", testbench.GenerateRandomPayload(t, 512)}, + } { + t.Run(tc.name, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) + + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Connect() + remoteFD, _ := dut.Accept(listenFD) + defer dut.Close(remoteFD) + + dut.SetSockOptInt(remoteFD, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + // TODO(b/129291778) The following socket option clears the DF bit on + // IP packets sent over the socket, and is currently not supported by + // gVisor. gVisor by default sends packets with DF=0 anyway, so the + // socket option being not supported does not affect the operation of + // this test. Once the socket option is supported, the following call + // can be changed to simply assert success. + ret, errno := dut.SetSockOptIntWithErrno(context.Background(), remoteFD, unix.IPPROTO_IP, linux.IP_MTU_DISCOVER, linux.IP_PMTUDISC_DONT) + if ret == -1 && errno != unix.ENOTSUP { + t.Fatalf("failed to set IP_MTU_DISCOVER socket option to IP_PMTUDISC_DONT: %s", errno) + } + + samplePayload := &testbench.Payload{Bytes: tc.payload} + + dut.Send(remoteFD, tc.payload, 0) + if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("failed to receive TCP segment sent for RTT calculation: %s", err) + } + // Let the DUT estimate RTO with RTT from the DATA-ACK. + // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which + // we can skip sending this ACK. + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + + dut.Send(remoteFD, tc.payload, 0) + expectTCP := &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum()))} + originalID, err := recvTCPSegment(&conn, expectTCP, samplePayload) + if err != nil { + t.Fatalf("failed to receive TCP segment: %s", err) + } + + retransmitID, err := recvTCPSegment(&conn, expectTCP, samplePayload) + if err != nil { + t.Fatalf("failed to receive retransmitted TCP segment: %s", err) + } + if originalID == retransmitID { + t.Fatalf("unexpectedly got retransmitted TCP segment with same IPv4 ID field=%d", originalID) + } + }) + } +} diff --git a/test/packetimpact/tests/ipv6_unknown_options_action_test.go b/test/packetimpact/tests/ipv6_unknown_options_action_test.go new file mode 100644 index 000000000..d301d8829 --- /dev/null +++ b/test/packetimpact/tests/ipv6_unknown_options_action_test.go @@ -0,0 +1,187 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6_unknown_options_action_test + +import ( + "encoding/binary" + "flag" + "net" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + tb.RegisterFlags(flag.CommandLine) +} + +func mkHopByHopOptionsExtHdr(optType byte) tb.Layer { + return &tb.IPv6HopByHopOptionsExtHdr{ + Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00}, + } +} + +func mkDestinationOptionsExtHdr(optType byte) tb.Layer { + return &tb.IPv6DestinationOptionsExtHdr{ + Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00}, + } +} + +func optionTypeFromAction(action header.IPv6OptionUnknownAction) byte { + return byte(action << 6) +} + +func TestIPv6UnknownOptionAction(t *testing.T) { + for _, tt := range []struct { + description string + mkExtHdr func(optType byte) tb.Layer + action header.IPv6OptionUnknownAction + multicastDst bool + wantICMPv6 bool + }{ + { + description: "0b00/hbh", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionSkip, + multicastDst: false, + wantICMPv6: false, + }, + { + description: "0b01/hbh", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscard, + multicastDst: false, + wantICMPv6: false, + }, + { + description: "0b10/hbh/unicast", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMP, + multicastDst: false, + wantICMPv6: true, + }, + { + description: "0b10/hbh/multicast", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMP, + multicastDst: true, + wantICMPv6: true, + }, + { + description: "0b11/hbh/unicast", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, + multicastDst: false, + wantICMPv6: true, + }, + { + description: "0b11/hbh/multicast", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, + multicastDst: true, + wantICMPv6: false, + }, + { + description: "0b00/destination", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionSkip, + multicastDst: false, + wantICMPv6: false, + }, + { + description: "0b01/destination", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscard, + multicastDst: false, + wantICMPv6: false, + }, + { + description: "0b10/destination/unicast", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMP, + multicastDst: false, + wantICMPv6: true, + }, + { + description: "0b10/destination/multicast", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMP, + multicastDst: true, + wantICMPv6: true, + }, + { + description: "0b11/destination/unicast", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, + multicastDst: false, + wantICMPv6: true, + }, + { + description: "0b11/destination/multicast", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, + multicastDst: true, + wantICMPv6: false, + }, + } { + t.Run(tt.description, func(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + ipv6Conn := tb.NewIPv6Conn(t, tb.IPv6{}, tb.IPv6{}) + conn := (*tb.Connection)(&ipv6Conn) + defer ipv6Conn.Close() + + outgoingOverride := tb.Layers{} + if tt.multicastDst { + outgoingOverride = tb.Layers{&tb.IPv6{ + DstAddr: tb.Address(tcpip.Address(net.ParseIP("ff02::1"))), + }} + } + + outgoing := conn.CreateFrame(outgoingOverride, tt.mkExtHdr(optionTypeFromAction(tt.action))) + conn.SendFrame(outgoing) + ipv6Sent := outgoing[1:] + invokingPacket, err := ipv6Sent.ToBytes() + if err != nil { + t.Fatalf("failed to serialize the outgoing packet: %s", err) + } + icmpv6Payload := make([]byte, 4) + // The pointer in the ICMPv6 parameter problem message should point to + // the option type of the unknown option. In our test case, it is the + // first option in the extension header whose option type is 2 bytes + // after the IPv6 header (after NextHeader and ExtHdrLen). + binary.BigEndian.PutUint32(icmpv6Payload, header.IPv6MinimumSize+2) + icmpv6Payload = append(icmpv6Payload, invokingPacket...) + gotICMPv6, err := ipv6Conn.ExpectFrame(tb.Layers{ + &tb.Ether{}, + &tb.IPv6{}, + &tb.ICMPv6{ + Type: tb.ICMPv6Type(header.ICMPv6ParamProblem), + Code: tb.Byte(2), + NDPPayload: icmpv6Payload, + }, + }, time.Second) + if tt.wantICMPv6 && err != nil { + t.Fatalf("expected ICMPv6 Parameter Problem but got none: %s", err) + } + if !tt.wantICMPv6 && gotICMPv6 != nil { + t.Fatalf("expected no ICMPv6 Parameter Problem but got one: %s", gotICMPv6) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go new file mode 100644 index 000000000..6e7ff41d7 --- /dev/null +++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go @@ -0,0 +1,108 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_close_wait_ack_test + +import ( + "flag" + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestCloseWaitAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP + seqNumOffset seqnum.Size + expectAck bool + }{ + {"OTW", GenerateOTWSeqSegment, 0, false}, + {"OTW", GenerateOTWSeqSegment, 1, true}, + {"OTW", GenerateOTWSeqSegment, 2, true}, + {"ACK", GenerateUnaccACKSegment, 0, false}, + {"ACK", GenerateUnaccACKSegment, 1, true}, + {"ACK", GenerateUnaccACKSegment, 2, true}, + } { + t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Connect() + acceptFd, _ := dut.Accept(listenFd) + + // Send a FIN to DUT to intiate the active close + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) + gotTCP, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err) + } + windowSize := seqnum.Size(*gotTCP.WindowSize) + + // Send a segment with OTW Seq / unacc ACK and expect an ACK back + conn.Send(tt.makeTestingTCP(&conn, tt.seqNumOffset, windowSize), &testbench.Payload{Bytes: []byte("Sample Data")}) + gotAck, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if tt.expectAck && err != nil { + t.Fatalf("expected an ack but got none: %s", err) + } + if !tt.expectAck && gotAck != nil { + t.Fatalf("expected no ack but got one: %s", gotAck) + } + + // Now let's verify DUT is indeed in CLOSE_WAIT + dut.Close(acceptFd) + if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { + t.Fatalf("expected DUT to send a FIN: %s", err) + } + // Ack the FIN from DUT + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + // Send some extra data to DUT + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: []byte("Sample Data")}) + if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected DUT to send an RST: %s", err) + } + }) + } +} + +// This generates an segment with seqnum = RCV.NXT + RCV.WND + seqNumOffset, the +// generated segment is only acceptable when seqNumOffset is 0, otherwise an ACK +// is expected from the receiver. +func GenerateOTWSeqSegment(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { + lastAcceptable := conn.LocalSeqNum().Add(windowSize) + otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) + return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)} +} + +// This generates an segment with acknum = SND.NXT + seqNumOffset, the generated +// segment is only acceptable when seqNumOffset is 0, otherwise an ACK is +// expected from the receiver. +func GenerateUnaccACKSegment(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { + lastAcceptable := conn.RemoteSeqNum() + unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) + return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)} +} diff --git a/test/packetimpact/tests/tcp_cork_mss_test.go b/test/packetimpact/tests/tcp_cork_mss_test.go new file mode 100644 index 000000000..fb8f48629 --- /dev/null +++ b/test/packetimpact/tests/tcp_cork_mss_test.go @@ -0,0 +1,84 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_cork_mss_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestTCPCorkMSS tests for segment coalesce and split as per MSS. +func TestTCPCorkMSS(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + + const mss = uint32(header.TCPDefaultMSS) + options := make([]byte, header.TCPOptionMSSLength) + header.EncodeMSSOption(mss, options) + conn.ConnectWithOptions(options) + + acceptFD, _ := dut.Accept(listenFD) + defer dut.Close(acceptFD) + + dut.SetSockOptInt(acceptFD, unix.IPPROTO_TCP, unix.TCP_CORK, 1) + + // Let the dut application send 2 small segments to be held up and coalesced + // until the application sends a larger segment to fill up to > MSS. + sampleData := []byte("Sample Data") + dut.Send(acceptFD, sampleData, 0) + dut.Send(acceptFD, sampleData, 0) + + expectedData := sampleData + expectedData = append(expectedData, sampleData...) + largeData := make([]byte, mss+1) + expectedData = append(expectedData, largeData...) + dut.Send(acceptFD, largeData, 0) + + // Expect the segments to be coalesced and sent and capped to MSS. + expectedPayload := testbench.Payload{Bytes: expectedData[:mss]} + if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + // Expect the coalesced segment to be split and transmitted. + expectedPayload = testbench.Payload{Bytes: expectedData[mss:]} + if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + + // Check for segments to *not* be held up because of TCP_CORK when + // the current send window is less than MSS. + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))}) + dut.Send(acceptFD, sampleData, 0) + dut.Send(acceptFD, sampleData, 0) + expectedPayload = testbench.Payload{Bytes: append(sampleData, sampleData...)} + if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) +} diff --git a/test/packetimpact/tests/tcp_handshake_window_size_test.go b/test/packetimpact/tests/tcp_handshake_window_size_test.go new file mode 100644 index 000000000..652b530d0 --- /dev/null +++ b/test/packetimpact/tests/tcp_handshake_window_size_test.go @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_handshake_window_size_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestTCPHandshakeWindowSize tests if the stack is honoring the window size +// communicated during handshake. +func TestTCPHandshakeWindowSize(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + + // Start handshake with zero window size. + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))}) + if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected SYN-ACK: %s", err) + } + // Update the advertised window size to a non-zero value with the ACK that + // completes the handshake. + // + // Set the window size with MSB set and expect the dut to treat it as + // an unsigned value. + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))}) + + acceptFd, _ := dut.Accept(listenFD) + defer dut.Close(acceptFd) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + // Since we advertised a zero window followed by a non-zero window, + // expect the dut to honor the recently advertised non-zero window + // and actually send out the data instead of probing for zero window. + dut.Send(acceptFd, sampleData, 0) + if _, err := conn.ExpectNextData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go new file mode 100644 index 000000000..b9b3e91d3 --- /dev/null +++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go @@ -0,0 +1,42 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_noaccept_close_rst_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestTcpNoAcceptCloseReset(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + conn.Connect() + defer conn.Close() + dut.Close(listenFd) + if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { + t.Fatalf("expected a RST-ACK packet but got none: %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go new file mode 100644 index 000000000..ad8c74234 --- /dev/null +++ b/test/packetimpact/tests/tcp_outside_the_window_test.go @@ -0,0 +1,93 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_outside_the_window_test + +import ( + "flag" + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestTCPOutsideTheWindows tests the behavior of the DUT when packets arrive +// that are inside or outside the TCP window. Packets that are outside the +// window should force an extra ACK, as described in RFC793 page 69: +// https://tools.ietf.org/html/rfc793#page-69 +func TestTCPOutsideTheWindow(t *testing.T) { + for _, tt := range []struct { + description string + tcpFlags uint8 + payload []testbench.Layer + seqNumOffset seqnum.Size + expectACK bool + }{ + {"SYN", header.TCPFlagSyn, nil, 0, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 0, true}, + {"ACK", header.TCPFlagAck, nil, 0, false}, + {"FIN", header.TCPFlagFin, nil, 0, false}, + {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}, 0, true}, + + {"SYN", header.TCPFlagSyn, nil, 1, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 1, true}, + {"ACK", header.TCPFlagAck, nil, 1, true}, + {"FIN", header.TCPFlagFin, nil, 1, false}, + {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}, 1, true}, + + {"SYN", header.TCPFlagSyn, nil, 2, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 2, true}, + {"ACK", header.TCPFlagAck, nil, 2, true}, + {"FIN", header.TCPFlagFin, nil, 2, false}, + {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}, 2, true}, + } { + t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + conn.Connect() + acceptFD, _ := dut.Accept(listenFD) + defer dut.Close(acceptFD) + + windowSize := seqnum.Size(*conn.SynAck().WindowSize) + tt.seqNumOffset + conn.Drain() + // Ignore whatever incrementing that this out-of-order packet might cause + // to the AckNum. + localSeqNum := testbench.Uint32(uint32(*conn.LocalSeqNum())) + conn.Send(testbench.TCP{ + Flags: testbench.Uint8(tt.tcpFlags), + SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum().Add(windowSize))), + }, tt.payload...) + timeout := 3 * time.Second + gotACK, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout) + if tt.expectACK && err != nil { + t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err) + } + if !tt.expectACK && gotACK != nil { + t.Fatalf("expected no ACK packet within %s but got one: %s", timeout, gotACK) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_paws_mechanism_test.go b/test/packetimpact/tests/tcp_paws_mechanism_test.go new file mode 100644 index 000000000..55db4ece6 --- /dev/null +++ b/test/packetimpact/tests/tcp_paws_mechanism_test.go @@ -0,0 +1,109 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_paws_mechanism_test + +import ( + "encoding/hex" + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestPAWSMechanism(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + + options := make([]byte, header.TCPOptionTSLength) + header.EncodeTSOption(currentTS(), 0, options) + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), Options: options}) + synAck, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("didn't get synack during handshake: %s", err) + } + parsedSynOpts := header.ParseSynOptions(synAck.Options, true) + if !parsedSynOpts.TS { + t.Fatalf("expected TSOpt from DUT, options we got:\n%s", hex.Dump(synAck.Options)) + } + tsecr := parsedSynOpts.TSVal + header.EncodeTSOption(currentTS(), tsecr, options) + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}) + acceptFD, _ := dut.Accept(listenFD) + defer dut.Close(acceptFD) + + sampleData := []byte("Sample Data") + sentTSVal := currentTS() + header.EncodeTSOption(sentTSVal, tsecr, options) + // 3ms here is chosen arbitrarily to make sure we have increasing timestamps + // every time we send one, it should not cause any flakiness because timestamps + // only need to be non-decreasing. + time.Sleep(3 * time.Millisecond) + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) + + gotTCP, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected an ACK but got none: %s", err) + } + + parsedOpts := header.ParseTCPOptions(gotTCP.Options) + if !parsedOpts.TS { + t.Fatalf("expected TS option in response, options we got:\n%s", hex.Dump(gotTCP.Options)) + } + if parsedOpts.TSVal < tsecr { + t.Fatalf("TSVal should be non-decreasing, but %d < %d", parsedOpts.TSVal, tsecr) + } + if parsedOpts.TSEcr != sentTSVal { + t.Fatalf("TSEcr should match our sent TSVal, %d != %d", parsedOpts.TSEcr, sentTSVal) + } + tsecr = parsedOpts.TSVal + lastAckNum := gotTCP.AckNum + + badTSVal := sentTSVal - 100 + header.EncodeTSOption(badTSVal, tsecr, options) + // 3ms here is chosen arbitrarily and this time.Sleep() should not cause flakiness + // due to the exact same reasoning discussed above. + time.Sleep(3 * time.Millisecond) + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) + + gotTCP, err = conn.Expect(testbench.TCP{AckNum: lastAckNum, Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected segment with AckNum %d but got none: %s", lastAckNum, err) + } + parsedOpts = header.ParseTCPOptions(gotTCP.Options) + if !parsedOpts.TS { + t.Fatalf("expected TS option in response, options we got:\n%s", hex.Dump(gotTCP.Options)) + } + if parsedOpts.TSVal < tsecr { + t.Fatalf("TSVal should be non-decreasing, but %d < %d", parsedOpts.TSVal, tsecr) + } + if parsedOpts.TSEcr != sentTSVal { + t.Fatalf("TSEcr should match our sent TSVal, %d != %d", parsedOpts.TSEcr, sentTSVal) + } +} + +func currentTS() uint32 { + return uint32(time.Now().UnixNano() / 1e6) +} diff --git a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go new file mode 100644 index 000000000..8fbec893b --- /dev/null +++ b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go @@ -0,0 +1,132 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_queue_receive_in_syn_sent_test + +import ( + "bytes" + "context" + "encoding/hex" + "errors" + "flag" + "net" + "sync" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestQueueReceiveInSynSent tests receive behavior when the TCP state +// is SYN-SENT. +// It tests for 2 variants where the receive is blocked and: +// (1) we complete handshake and send sample data. +// (2) we send a TCP RST. +func TestQueueReceiveInSynSent(t *testing.T) { + for _, tt := range []struct { + description string + reset bool + }{ + {description: "Send DATA", reset: false}, + {description: "Send RST", reset: true}, + } { + t.Run(tt.description, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + + socket, remotePort := dut.CreateBoundSocket(unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + + sampleData := []byte("Sample Data") + + dut.SetNonBlocking(socket, true) + if _, err := dut.ConnectWithErrno(context.Background(), socket, conn.LocalAddr()); !errors.Is(err, syscall.EINPROGRESS) { + t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err) + } + if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil { + t.Fatalf("expected a SYN from DUT, but got none: %s", err) + } + + if _, _, err := dut.RecvWithErrno(context.Background(), socket, int32(len(sampleData)), 0); err != syscall.Errno(unix.EWOULDBLOCK) { + t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err) + } + + // Test blocking read. + dut.SetNonBlocking(socket, false) + + var wg sync.WaitGroup + defer wg.Wait() + wg.Add(1) + var block sync.WaitGroup + block.Add(1) + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + + block.Done() + // Issue RECEIVE call in SYN-SENT, this should be queued for + // process until the connection is established. + n, buff, err := dut.RecvWithErrno(ctx, socket, int32(len(sampleData)), 0) + if tt.reset { + if err != syscall.Errno(unix.ECONNREFUSED) { + t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err) + } + if n != -1 { + t.Errorf("expected return value %d, got %d", -1, n) + } + return + } + if n == -1 { + t.Errorf("failed to recv on DUT: %s", err) + } + if got := buff[:n]; !bytes.Equal(got, sampleData) { + t.Errorf("received data doesn't match, got:\n%s, want:\n%s", hex.Dump(got), hex.Dump(sampleData)) + } + }() + + // Wait for the goroutine to be scheduled and before it + // blocks on endpoint receive. + block.Wait() + // The following sleep is used to prevent the connection + // from being established before we are blocked on Recv. + time.Sleep(100 * time.Millisecond) + + if tt.reset { + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + return + } + + // Bring the connection to Established. + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}) + if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK from DUT, but got none: %s", err) + } + + // Send sample payload and expect an ACK. + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}) + if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK from DUT, but got none: %s", err) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_reordering_test.go b/test/packetimpact/tests/tcp_reordering_test.go new file mode 100644 index 000000000..a5378a9dd --- /dev/null +++ b/test/packetimpact/tests/tcp_reordering_test.go @@ -0,0 +1,174 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reordering_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + tb.RegisterFlags(flag.CommandLine) +} + +func TestReorderingWindow(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + + // Enable SACK. + opts := make([]byte, 40) + optsOff := 0 + optsOff += header.EncodeNOP(opts[optsOff:]) + optsOff += header.EncodeNOP(opts[optsOff:]) + optsOff += header.EncodeSACKPermittedOption(opts[optsOff:]) + + // Ethernet guarantees that the MTU is at least 1500 bytes. + const minMTU = 1500 + const mss = minMTU - header.IPv4MinimumSize - header.TCPMinimumSize + optsOff += header.EncodeMSSOption(mss, opts[optsOff:]) + + conn.ConnectWithOptions(opts[:optsOff]) + + acceptFd, _ := dut.Accept(listenFd) + defer dut.Close(acceptFd) + + if tb.DUTType == "linux" { + // Linux has changed its handling of reordering, force the old behavior. + dut.SetSockOpt(acceptFd, unix.IPPROTO_TCP, unix.TCP_CONGESTION, []byte("reno")) + } + + pls := dut.GetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_MAXSEG) + if tb.DUTType == "netstack" { + // netstack does not impliment TCP_MAXSEG correctly. Fake it + // here. Netstack uses the max SACK size which is 32. The MSS + // option is 8 bytes, making the total 36 bytes. + pls = mss - 36 + } + + payload := make([]byte, pls) + + seqNum1 := *conn.RemoteSeqNum() + const numPkts = 10 + // Send some packets, checking that we receive each. + for i, sn := 0, seqNum1; i < numPkts; i++ { + dut.Send(acceptFd, payload, 0) + + gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) + sn.UpdateForward(seqnum.Size(len(payload))) + if err != nil { + t.Errorf("Expect #%d: %s", i+1, err) + continue + } + if gotOne == nil { + t.Errorf("#%d: expected a packet within a second but got none", i+1) + } + } + + seqNum2 := *conn.RemoteSeqNum() + + // SACK packets #2-4. + sackBlock := make([]byte, 40) + sbOff := 0 + sbOff += header.EncodeNOP(sackBlock[sbOff:]) + sbOff += header.EncodeNOP(sackBlock[sbOff:]) + sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{ + seqNum1.Add(seqnum.Size(len(payload))), + seqNum1.Add(seqnum.Size(4 * len(payload))), + }}, sackBlock[sbOff:]) + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) + + // ACK first packet. + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1) + uint32(len(payload)))}) + + // Check for retransmit. + gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(seqNum1))}, time.Second) + if err != nil { + t.Error("Expect for retransmit:", err) + } + if gotOne == nil { + t.Error("expected a retransmitted packet within a second but got none") + } + + // ACK all send packets with a DSACK block for packet #1. This tells + // the other end that we got both the original and retransmit for + // packet #1. + dsackBlock := make([]byte, 40) + dsbOff := 0 + dsbOff += header.EncodeNOP(dsackBlock[dsbOff:]) + dsbOff += header.EncodeNOP(dsackBlock[dsbOff:]) + dsbOff += header.EncodeSACKBlocks([]header.SACKBlock{{ + seqNum1.Add(seqnum.Size(len(payload))), + seqNum1.Add(seqnum.Size(4 * len(payload))), + }}, dsackBlock[dsbOff:]) + + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum2)), Options: dsackBlock[:dsbOff]}) + + // Send half of the original window of packets, checking that we + // received each. + for i, sn := 0, seqNum2; i < numPkts/2; i++ { + dut.Send(acceptFd, payload, 0) + + gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) + sn.UpdateForward(seqnum.Size(len(payload))) + if err != nil { + t.Errorf("Expect #%d: %s", i+1, err) + continue + } + if gotOne == nil { + t.Errorf("#%d: expected a packet within a second but got none", i+1) + } + } + + if tb.DUTType == "netstack" { + // The window should now be halved, so we should receive any + // more, even if we send them. + dut.Send(acceptFd, payload, 0) + if got, err := conn.Expect(tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { + t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got) + } + return + } + + // Linux reduces the window by three. Check that we can receive the rest. + for i, sn := 0, seqNum2.Add(seqnum.Size(numPkts/2*len(payload))); i < 2; i++ { + dut.Send(acceptFd, payload, 0) + + gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) + sn.UpdateForward(seqnum.Size(len(payload))) + if err != nil { + t.Errorf("Expect #%d: %s", i+1, err) + continue + } + if gotOne == nil { + t.Errorf("#%d: expected a packet within a second but got none", i+1) + } + } + + // The window should now be full. + dut.Send(acceptFd, payload, 0) + if got, err := conn.Expect(tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { + t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got) + } +} diff --git a/test/packetimpact/tests/tcp_retransmits_test.go b/test/packetimpact/tests/tcp_retransmits_test.go new file mode 100644 index 000000000..6940eb7fb --- /dev/null +++ b/test/packetimpact/tests/tcp_retransmits_test.go @@ -0,0 +1,84 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_retransmits_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestRetransmits tests retransmits occur at exponentially increasing +// time intervals. +func TestRetransmits(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Connect() + acceptFd, _ := dut.Accept(listenFd) + defer dut.Close(acceptFd) + + dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + dut.Send(acceptFd, sampleData, 0) + if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + // Give a chance for the dut to estimate RTO with RTT from the DATA-ACK. + // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which + // we can skip sending this ACK. + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + + startRTO := time.Second + current := startRTO + first := time.Now() + dut.Send(acceptFd, sampleData, 0) + seq := testbench.Uint32(uint32(*conn.RemoteSeqNum())) + if _, err := conn.ExpectData(&testbench.TCP{SeqNum: seq}, samplePayload, startRTO); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + // Expect retransmits of the same segment. + for i := 0; i < 5; i++ { + start := time.Now() + if _, err := conn.ExpectData(&testbench.TCP{SeqNum: seq}, samplePayload, 2*current); err != nil { + t.Fatalf("expected payload was not received: %s loop %d", err, i) + } + if i == 0 { + startRTO = time.Now().Sub(first) + current = 2 * startRTO + continue + } + // Check if the probes came at exponentially increasing intervals. + if p := time.Since(start); p < current-startRTO { + t.Fatalf("retransmit came sooner interval %d probe %d", p, i) + } + current *= 2 + } +} diff --git a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go new file mode 100644 index 000000000..90ab85419 --- /dev/null +++ b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go @@ -0,0 +1,105 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_send_window_sizes_piggyback_test + +import ( + "flag" + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestSendWindowSizesPiggyback tests cases where segment sizes are close to +// sender window size and checks for ACK piggybacking for each of those case. +func TestSendWindowSizesPiggyback(t *testing.T) { + sampleData := []byte("Sample Data") + segmentSize := uint16(len(sampleData)) + // Advertise receive window sizes that are lesser, equal to or greater than + // enqueued segment size and check for segment transmits. The test attempts + // to enqueue a segment on the dut before acknowledging previous segment and + // lets the dut piggyback any ACKs along with the enqueued segment. + for _, tt := range []struct { + description string + windowSize uint16 + expectedPayload1 []byte + expectedPayload2 []byte + enqueue bool + }{ + // Expect the first segment to be split as it cannot be accomodated in + // the sender window. This means we need not enqueue a new segment after + // the first segment. + {"WindowSmallerThanSegment", segmentSize - 1, sampleData[:(segmentSize - 1)], sampleData[(segmentSize - 1):], false /* enqueue */}, + + {"WindowEqualToSegment", segmentSize, sampleData, sampleData, true /* enqueue */}, + + // Expect the second segment to not be split as its size is greater than + // the available sender window size. The segments should not be split + // when there is pending unacknowledged data and the segment-size is + // greater than available sender window. + {"WindowGreaterThanSegment", segmentSize + 1, sampleData, sampleData, true /* enqueue */}, + } { + t.Run(fmt.Sprintf("%s%d", tt.description, tt.windowSize), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort, WindowSize: testbench.Uint16(tt.windowSize)}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Connect() + acceptFd, _ := dut.Accept(listenFd) + defer dut.Close(acceptFd) + + dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + expectedTCP := testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)} + + dut.Send(acceptFd, sampleData, 0) + expectedPayload := testbench.Payload{Bytes: tt.expectedPayload1} + if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + + // Expect any enqueued segment to be transmitted by the dut along with + // piggybacked ACK for our data. + + if tt.enqueue { + // Enqueue a segment for the dut to transmit. + dut.Send(acceptFd, sampleData, 0) + } + + // Send ACK for the previous segment along with data for the dut to + // receive and ACK back. Sending this ACK would make room for the dut + // to transmit any enqueued segment. + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData}) + + // Expect the dut to piggyback the ACK for received data along with + // the segment enqueued for transmit. + expectedPayload = testbench.Payload{Bytes: tt.expectedPayload2} + if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_synrcvd_reset_test.go b/test/packetimpact/tests/tcp_synrcvd_reset_test.go new file mode 100644 index 000000000..7d5deab01 --- /dev/null +++ b/test/packetimpact/tests/tcp_synrcvd_reset_test.go @@ -0,0 +1,52 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_syn_reset_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestTCPSynRcvdReset tests transition from SYN-RCVD to CLOSED. +func TestTCPSynRcvdReset(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + + // Expect dut connection to have transitioned to SYN-RCVD state. + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected SYN-ACK %s", err) + } + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}) + // Expect the connection to have transitioned SYN-RCVD to CLOSED. + // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + t.Fatalf("expected a TCP RST %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_synsent_reset_test.go b/test/packetimpact/tests/tcp_synsent_reset_test.go new file mode 100644 index 000000000..6898a2239 --- /dev/null +++ b/test/packetimpact/tests/tcp_synsent_reset_test.go @@ -0,0 +1,88 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_synsent_reset_test + +import ( + "flag" + "net" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + tb.RegisterFlags(flag.CommandLine) +} + +// dutSynSentState sets up the dut connection in SYN-SENT state. +func dutSynSentState(t *testing.T) (*tb.DUT, *tb.TCPIPv4, uint16, uint16) { + dut := tb.NewDUT(t) + + clientFD, clientPort := dut.CreateBoundSocket(unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(tb.RemoteIPv4)) + port := uint16(9001) + conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &port, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &port}) + + sa := unix.SockaddrInet4{Port: int(port)} + copy(sa.Addr[:], net.IP(net.ParseIP(tb.LocalIPv4)).To4()) + // Bring the dut to SYN-SENT state with a non-blocking connect. + dut.Connect(clientFD, &sa) + if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil { + t.Fatalf("expected SYN\n") + } + + return &dut, &conn, port, clientPort +} + +// TestTCPSynSentReset tests RFC793, p67: SYN-SENT to CLOSED transition. +func TestTCPSynSentReset(t *testing.T) { + dut, conn, _, _ := dutSynSentState(t) + defer conn.Close() + defer dut.TearDown() + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + // Expect the connection to have closed. + // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + t.Fatalf("expected a TCP RST") + } +} + +// TestTCPSynSentRcvdReset tests RFC793, p70, SYN-SENT to SYN-RCVD to CLOSED +// transitions. +func TestTCPSynSentRcvdReset(t *testing.T) { + dut, c, remotePort, clientPort := dutSynSentState(t) + defer dut.TearDown() + defer c.Close() + + conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &remotePort, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &remotePort}) + defer conn.Close() + // Initiate new SYN connection with the same port pair + // (simultaneous open case), expect the dut connection to move to + // SYN-RCVD state + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected SYN-ACK %s\n", err) + } + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}) + // Expect the connection to have transitioned SYN-RCVD to CLOSED. + // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + t.Fatalf("expected a TCP RST") + } +} diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go new file mode 100644 index 000000000..87e45d765 --- /dev/null +++ b/test/packetimpact/tests/tcp_user_timeout_test.go @@ -0,0 +1,105 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_user_timeout_test + +import ( + "flag" + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func sendPayload(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error { + sampleData := make([]byte, 100) + for i := range sampleData { + sampleData[i] = uint8(i) + } + conn.Drain() + dut.Send(fd, sampleData, 0) + if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil { + return fmt.Errorf("expected data but got none: %w", err) + } + return nil +} + +func sendFIN(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error { + dut.Close(fd) + return nil +} + +func TestTCPUserTimeout(t *testing.T) { + for _, tt := range []struct { + description string + userTimeout time.Duration + sendDelay time.Duration + }{ + {"NoUserTimeout", 0, 3 * time.Second}, + {"ACKBeforeUserTimeout", 5 * time.Second, 4 * time.Second}, + {"ACKAfterUserTimeout", 5 * time.Second, 7 * time.Second}, + } { + for _, ttf := range []struct { + description string + f func(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error + }{ + {"AfterPayload", sendPayload}, + {"AfterFIN", sendFIN}, + } { + t.Run(tt.description+ttf.description, func(t *testing.T) { + // Create a socket, listen, TCP handshake, and accept. + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + conn.Connect() + acceptFD, _ := dut.Accept(listenFD) + + if tt.userTimeout != 0 { + dut.SetSockOptInt(acceptFD, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int32(tt.userTimeout.Milliseconds())) + } + + if err := ttf.f(&conn, &dut, acceptFD); err != nil { + t.Fatal(err) + } + + time.Sleep(tt.sendDelay) + conn.Drain() + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + + // If TCP_USER_TIMEOUT was set and the above delay was longer than the + // TCP_USER_TIMEOUT then the DUT should send a RST in response to the + // testbench's packet. + expectRST := tt.userTimeout != 0 && tt.sendDelay > tt.userTimeout + expectTimeout := 5 * time.Second + got, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, expectTimeout) + if expectRST && err != nil { + t.Errorf("expected RST packet within %s but got none: %s", expectTimeout, err) + } + if !expectRST && got != nil { + t.Errorf("expected no RST packet within %s but got one: %s", expectTimeout, got) + } + }) + } + } +} diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go new file mode 100644 index 000000000..e78d04756 --- /dev/null +++ b/test/packetimpact/tests/tcp_window_shrink_test.go @@ -0,0 +1,73 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_window_shrink_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestWindowShrink(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Connect() + acceptFd, _ := dut.Accept(listenFd) + defer dut.Close(acceptFd) + + dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + dut.Send(acceptFd, sampleData, 0) + if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + + dut.Send(acceptFd, sampleData, 0) + dut.Send(acceptFd, sampleData, 0) + if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + // We close our receiving window here + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + + dut.Send(acceptFd, []byte("Sample Data"), 0) + // Note: There is another kind of zero-window probing which Windows uses (by sending one + // new byte at `RemoteSeqNum`), if netstack wants to go that way, we may want to change + // the following lines. + expectedRemoteSeqNum := *conn.RemoteSeqNum() - 1 + if _, err := conn.ExpectData(&testbench.TCP{SeqNum: testbench.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil { + t.Fatalf("expected a packet with sequence number %d: %s", expectedRemoteSeqNum, err) + } +} diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go new file mode 100644 index 000000000..8c89d57c9 --- /dev/null +++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go @@ -0,0 +1,105 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_zero_window_probe_retransmit_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestZeroWindowProbeRetransmit tests retransmits of zero window probes +// to be sent at exponentially inreasing time intervals. +func TestZeroWindowProbeRetransmit(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Connect() + acceptFd, _ := dut.Accept(listenFd) + defer dut.Close(acceptFd) + + dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + // Send and receive sample data to the dut. + dut.Send(acceptFd, sampleData, 0) + if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected packet was not received: %s", err) + } + + // Check for the dut to keep the connection alive as long as the zero window + // probes are acknowledged. Check if the zero window probes are sent at + // exponentially increasing intervals. The timeout intervals are function + // of the recorded first zero probe transmission duration. + // + // Advertize zero receive window again. + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1)) + ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum())) + + startProbeDuration := time.Second + current := startProbeDuration + first := time.Now() + // Ask the dut to send out data. + dut.Send(acceptFd, sampleData, 0) + // Expect the dut to keep the connection alive as long as the remote is + // acknowledging the zero-window probes. + for i := 0; i < 5; i++ { + start := time.Now() + // Expect zero-window probe with a timeout which is a function of the typical + // first retransmission time. The retransmission times is supposed to + // exponentially increase. + if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil { + t.Fatalf("expected a probe with sequence number %d: loop %d", probeSeq, i) + } + if i == 0 { + startProbeDuration = time.Now().Sub(first) + current = 2 * startProbeDuration + continue + } + // Check if the probes came at exponentially increasing intervals. + if got, want := time.Since(start), current-startProbeDuration; got < want { + t.Errorf("got zero probe %d after %s, want >= %s", i, got, want) + } + // Acknowledge the zero-window probes from the dut. + conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + current *= 2 + } + // Advertize non-zero window. + conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) + // Expect the dut to recover and transmit data. + if _, err := conn.ExpectData(&testbench. + TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_zero_window_probe_test.go b/test/packetimpact/tests/tcp_zero_window_probe_test.go new file mode 100644 index 000000000..649fd5699 --- /dev/null +++ b/test/packetimpact/tests/tcp_zero_window_probe_test.go @@ -0,0 +1,112 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_zero_window_probe_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestZeroWindowProbe tests few cases of zero window probing over the +// same connection. +func TestZeroWindowProbe(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Connect() + acceptFd, _ := dut.Accept(listenFd) + defer dut.Close(acceptFd) + + dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + start := time.Now() + // Send and receive sample data to the dut. + dut.Send(acceptFd, sampleData, 0) + if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + sendTime := time.Now().Sub(start) + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected packet was not received: %s", err) + } + + // Test 1: Check for receive of a zero window probe, record the duration for + // probe to be sent. + // + // Advertize zero window to the dut. + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + + // Expected sequence number of the zero window probe. + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1)) + // Expected ack number of the ACK for the probe. + ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum())) + + // Expect there are no zero-window probes sent until there is data to be sent out + // from the dut. + if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, 2*time.Second); err == nil { + t.Fatalf("unexpected packet with sequence number %d: %s", probeSeq, err) + } + + start = time.Now() + // Ask the dut to send out data. + dut.Send(acceptFd, sampleData, 0) + // Expect zero-window probe from the dut. + if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { + t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err) + } + // Expect the probe to be sent after some time. Compare against the previous + // time recorded when the dut immediately sends out data on receiving the + // send command. + if startProbeDuration := time.Now().Sub(start); startProbeDuration <= sendTime { + t.Fatalf("expected the first probe to be sent out after retransmission interval, got %s want > %s", startProbeDuration, sendTime) + } + + // Test 2: Check if the dut recovers on advertizing non-zero receive window. + // and sends out the sample payload after the send window opens. + // + // Advertize non-zero window to the dut and ack the zero window probe. + conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) + // Expect the dut to recover and transmit data. + if _, err := conn.ExpectData(&testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + + // Test 3: Sanity check for dut's processing of a similar probe it sent. + // Check if the dut responds as we do for a similar probe sent to it. + // Basically with sequence number to one byte behind the unacknowledged + // sequence number. + p := testbench.Uint32(uint32(*conn.LocalSeqNum())) + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum() - 1))}) + if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil { + t.Fatalf("expected a packet with ack number: %d: %s", p, err) + } +} diff --git a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go new file mode 100644 index 000000000..3c467b14f --- /dev/null +++ b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go @@ -0,0 +1,98 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_zero_window_probe_usertimeout_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestZeroWindowProbeUserTimeout sanity tests user timeout when we are +// retransmitting zero window probes. +func TestZeroWindowProbeUserTimeout(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Connect() + acceptFd, _ := dut.Accept(listenFd) + defer dut.Close(acceptFd) + + dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + // Send and receive sample data to the dut. + dut.Send(acceptFd, sampleData, 0) + if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected packet was not received: %s", err) + } + + // Test 1: Check for receive of a zero window probe, record the duration for + // probe to be sent. + // + // Advertize zero window to the dut. + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + + // Expected sequence number of the zero window probe. + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1)) + start := time.Now() + // Ask the dut to send out data. + dut.Send(acceptFd, sampleData, 0) + // Expect zero-window probe from the dut. + if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { + t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err) + } + // Record the duration for first probe, the dut sends the zero window probe after + // a retransmission time interval. + startProbeDuration := time.Now().Sub(start) + + // Test 2: Check if the dut times out the connection by honoring usertimeout + // when the dut is sending zero-window probes. + // + // Reduce the retransmit timeout. + dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int32(startProbeDuration.Milliseconds())) + // Advertize zero window again. + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + // Ask the dut to send out data that would trigger zero window probe retransmissions. + dut.Send(acceptFd, sampleData, 0) + + // Wait for the connection to timeout after multiple zero-window probe retransmissions. + time.Sleep(8 * startProbeDuration) + + // Expect the connection to have timed out and closed which would cause the dut + // to reply with a RST to the ACK we send. + conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + t.Fatalf("expected a TCP RST") + } +} diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go new file mode 100644 index 000000000..b754918f6 --- /dev/null +++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go @@ -0,0 +1,365 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_icmp_error_propagation_test + +import ( + "context" + "flag" + "fmt" + "net" + "sync" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +type connectionMode bool + +func (c connectionMode) String() string { + if c { + return "Connected" + } + return "Connectionless" +} + +type icmpError int + +const ( + portUnreachable icmpError = iota + timeToLiveExceeded +) + +func (e icmpError) String() string { + switch e { + case portUnreachable: + return "PortUnreachable" + case timeToLiveExceeded: + return "TimeToLiveExpired" + } + return "Unknown ICMP error" +} + +func (e icmpError) ToICMPv4() *testbench.ICMPv4 { + switch e { + case portUnreachable: + return &testbench.ICMPv4{Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), Code: testbench.Uint8(header.ICMPv4PortUnreachable)} + case timeToLiveExceeded: + return &testbench.ICMPv4{Type: testbench.ICMPv4Type(header.ICMPv4TimeExceeded), Code: testbench.Uint8(header.ICMPv4TTLExceeded)} + } + return nil +} + +type errorDetection struct { + name string + useValidConn bool + f func(context.Context, testData) error +} + +type testData struct { + dut *testbench.DUT + conn *testbench.UDPIPv4 + remoteFD int32 + remotePort uint16 + cleanFD int32 + cleanPort uint16 + wantErrno syscall.Errno +} + +// wantErrno computes the errno to expect given the connection mode of a UDP +// socket and the ICMP error it will receive. +func wantErrno(c connectionMode, icmpErr icmpError) syscall.Errno { + if c && icmpErr == portUnreachable { + return syscall.Errno(unix.ECONNREFUSED) + } + return syscall.Errno(0) +} + +// sendICMPError sends an ICMP error message in response to a UDP datagram. +func sendICMPError(conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UDP) error { + layers := (*testbench.Connection)(conn).CreateFrame(nil) + layers = layers[:len(layers)-1] + ip, ok := udp.Prev().(*testbench.IPv4) + if !ok { + return fmt.Errorf("expected %s to be IPv4", udp.Prev()) + } + if icmpErr == timeToLiveExceeded { + *ip.TTL = 1 + // Let serialization recalculate the checksum since we set the TTL + // to 1. + ip.Checksum = nil + } + // Note that the ICMP payload is valid in this case because the UDP + // payload is empty. If the UDP payload were not empty, the packet + // length during serialization may not be calculated correctly, + // resulting in a mal-formed packet. + layers = append(layers, icmpErr.ToICMPv4(), ip, udp) + + (*testbench.Connection)(conn).SendFrameStateless(layers) + return nil +} + +// testRecv tests observing the ICMP error through the recv syscall. A packet +// is sent to the DUT, and if wantErrno is non-zero, then the first recv should +// fail and the second should succeed. Otherwise if wantErrno is zero then the +// first recv should succeed immediately. +func testRecv(ctx context.Context, d testData) error { + // Check that receiving on the clean socket works. + d.conn.Send(testbench.UDP{DstPort: &d.cleanPort}) + d.dut.Recv(d.cleanFD, 100, 0) + + d.conn.Send(testbench.UDP{}) + + if d.wantErrno != syscall.Errno(0) { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + ret, _, err := d.dut.RecvWithErrno(ctx, d.remoteFD, 100, 0) + if ret != -1 { + return fmt.Errorf("recv after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) + } + if err != d.wantErrno { + return fmt.Errorf("recv after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) + } + } + + d.dut.Recv(d.remoteFD, 100, 0) + return nil +} + +// testSendTo tests observing the ICMP error through the send syscall. If +// wantErrno is non-zero, the first send should fail and a subsequent send +// should suceed; while if wantErrno is zero then the first send should just +// succeed. +func testSendTo(ctx context.Context, d testData) error { + // Check that sending on the clean socket works. + d.dut.SendTo(d.cleanFD, nil, 0, d.conn.LocalAddr()) + if _, err := d.conn.Expect(testbench.UDP{SrcPort: &d.cleanPort}, time.Second); err != nil { + return fmt.Errorf("did not receive UDP packet from clean socket on DUT: %s", err) + } + + if d.wantErrno != syscall.Errno(0) { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + ret, err := d.dut.SendToWithErrno(ctx, d.remoteFD, nil, 0, d.conn.LocalAddr()) + + if ret != -1 { + return fmt.Errorf("sendto after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) + } + if err != d.wantErrno { + return fmt.Errorf("sendto after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) + } + } + + d.dut.SendTo(d.remoteFD, nil, 0, d.conn.LocalAddr()) + if _, err := d.conn.Expect(testbench.UDP{}, time.Second); err != nil { + return fmt.Errorf("did not receive UDP packet as expected: %s", err) + } + return nil +} + +func testSockOpt(_ context.Context, d testData) error { + // Check that there's no pending error on the clean socket. + if errno := syscall.Errno(d.dut.GetSockOptInt(d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != syscall.Errno(0) { + return fmt.Errorf("unexpected error (%[1]d) %[1]v on clean socket", errno) + } + + if errno := syscall.Errno(d.dut.GetSockOptInt(d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno { + return fmt.Errorf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, d.wantErrno) + } + + // Check that after clearing socket error, sending doesn't fail. + d.dut.SendTo(d.remoteFD, nil, 0, d.conn.LocalAddr()) + if _, err := d.conn.Expect(testbench.UDP{}, time.Second); err != nil { + return fmt.Errorf("did not receive UDP packet as expected: %s", err) + } + return nil +} + +// TestUDPICMPErrorPropagation tests that ICMP error messages in response to +// UDP datagrams are processed correctly. RFC 1122 section 4.1.3.3 states that: +// "UDP MUST pass to the application layer all ICMP error messages that it +// receives from the IP layer." +// +// The test cases are parametrized in 3 dimensions: 1. the UDP socket is either +// put into connection mode or left connectionless, 2. the ICMP message type +// and code, and 3. the method by which the ICMP error is observed on the +// socket: sendto, recv, or getsockopt(SO_ERROR). +// +// Linux's udp(7) man page states: "All fatal errors will be passed to the user +// as an error return even when the socket is not connected. This includes +// asynchronous errors received from the network." In practice, the only +// combination of parameters to the test that causes an error to be observable +// on the UDP socket is receiving a port unreachable message on a connected +// socket. +func TestUDPICMPErrorPropagation(t *testing.T) { + for _, connect := range []connectionMode{true, false} { + for _, icmpErr := range []icmpError{portUnreachable, timeToLiveExceeded} { + wantErrno := wantErrno(connect, icmpErr) + + for _, errDetect := range []errorDetection{ + errorDetection{"SendTo", false, testSendTo}, + // Send to an address that's different from the one that caused an ICMP + // error to be returned. + errorDetection{"SendToValid", true, testSendTo}, + errorDetection{"Recv", false, testRecv}, + errorDetection{"SockOpt", false, testSockOpt}, + } { + t.Run(fmt.Sprintf("%s/%s/%s", connect, icmpErr, errDetect.name), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + + remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(remoteFD) + + // Create a second, clean socket on the DUT to ensure that the ICMP + // error messages only affect the sockets they are intended for. + cleanFD, cleanPort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(cleanFD) + + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close() + + if connect { + dut.Connect(remoteFD, conn.LocalAddr()) + dut.Connect(cleanFD, conn.LocalAddr()) + } + + dut.SendTo(remoteFD, nil, 0, conn.LocalAddr()) + udp, err := conn.Expect(testbench.UDP{}, time.Second) + if err != nil { + t.Fatalf("did not receive message from DUT: %s", err) + } + + if err := sendICMPError(&conn, icmpErr, udp); err != nil { + t.Fatal(err) + } + + errDetectConn := &conn + if errDetect.useValidConn { + // connClean is a UDP socket on the test runner that was not + // involved in the generation of the ICMP error. As such, + // interactions between it and the the DUT should be independent of + // the ICMP error at least at the port level. + connClean := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer connClean.Close() + + errDetectConn = &connClean + } + + if err := errDetect.f(context.Background(), testData{&dut, errDetectConn, remoteFD, remotePort, cleanFD, cleanPort, wantErrno}); err != nil { + t.Fatal(err) + } + }) + } + } + } +} + +// TestICMPErrorDuringUDPRecv tests behavior when a UDP socket is in the middle +// of a blocking recv and receives an ICMP error. +func TestICMPErrorDuringUDPRecv(t *testing.T) { + for _, connect := range []connectionMode{true, false} { + for _, icmpErr := range []icmpError{portUnreachable, timeToLiveExceeded} { + wantErrno := wantErrno(connect, icmpErr) + + t.Run(fmt.Sprintf("%s/%s", connect, icmpErr), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + + remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(remoteFD) + + // Create a second, clean socket on the DUT to ensure that the ICMP + // error messages only affect the sockets they are intended for. + cleanFD, cleanPort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(cleanFD) + + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close() + + if connect { + dut.Connect(remoteFD, conn.LocalAddr()) + dut.Connect(cleanFD, conn.LocalAddr()) + } + + dut.SendTo(remoteFD, nil, 0, conn.LocalAddr()) + udp, err := conn.Expect(testbench.UDP{}, time.Second) + if err != nil { + t.Fatalf("did not receive message from DUT: %s", err) + } + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + + if wantErrno != syscall.Errno(0) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0) + if ret != -1 { + t.Errorf("recv during ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno) + return + } + if err != wantErrno { + t.Errorf("recv during ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, wantErrno) + return + } + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0); ret == -1 { + t.Errorf("recv after ICMP error failed with (%[1]d) %[1]", err) + } + }() + + go func() { + defer wg.Done() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if ret, _, err := dut.RecvWithErrno(ctx, cleanFD, 100, 0); ret == -1 { + t.Errorf("recv on clean socket failed with (%[1]d) %[1]", err) + } + }() + + // TODO(b/155684889) This sleep is to allow time for the DUT to + // actually call recv since we want the ICMP error to arrive during the + // blocking recv, and should be replaced when a better synchronization + // alternative is available. + time.Sleep(2 * time.Second) + + if err := sendICMPError(&conn, icmpErr, udp); err != nil { + t.Fatal(err) + } + + conn.Send(testbench.UDP{DstPort: &cleanPort}) + conn.Send(testbench.UDP{}) + wg.Wait() + }) + } + } +} diff --git a/test/packetimpact/tests/udp_recv_multicast_test.go b/test/packetimpact/tests/udp_recv_multicast_test.go new file mode 100644 index 000000000..77a9bfa1d --- /dev/null +++ b/test/packetimpact/tests/udp_recv_multicast_test.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_recv_multicast_test + +import ( + "flag" + "net" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestUDPRecvMulticast(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(boundFD) + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close() + conn.SendIP(testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(net.ParseIP("224.0.0.1").To4()))}, testbench.UDP{}) + dut.Recv(boundFD, 100, 0) +} diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go new file mode 100644 index 000000000..224feef85 --- /dev/null +++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go @@ -0,0 +1,90 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_send_recv_dgram_test + +import ( + "flag" + "net" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestUDPRecv(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(boundFD) + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close() + + testCases := []struct { + name string + payload []byte + }{ + {"emptypayload", nil}, + {"small payload", []byte("hello world")}, + {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)}, + // Even though UDP allows larger dgrams we don't test it here as + // they need to be fragmented and written out as individual + // frames. + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + conn.Send(testbench.UDP{}, &testbench.Payload{Bytes: tc.payload}) + if got, want := string(dut.Recv(boundFD, int32(len(tc.payload)), 0)), string(tc.payload); got != want { + t.Fatalf("received payload does not match sent payload got: %s, want: %s", got, want) + } + }) + } +} + +func TestUDPSend(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(boundFD) + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close() + + testCases := []struct { + name string + payload []byte + }{ + {"emptypayload", nil}, + {"small payload", []byte("hello world")}, + {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)}, + // Even though UDP allows larger dgrams we don't test it here as + // they need to be fragmented and written out as individual + // frames. + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + conn.Drain() + if got, want := int(dut.SendTo(boundFD, tc.payload, 0, conn.LocalAddr())), len(tc.payload); got != want { + t.Fatalf("short write got: %d, want: %d", got, want) + } + if _, err := conn.ExpectData(testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: tc.payload}, 1*time.Second); err != nil { + t.Fatal(err) + } + }) + } +} |