diff options
author | Googler <noreply@google.com> | 2018-04-27 10:37:02 -0700 |
---|---|---|
committer | Adin Scannell <ascannell@google.com> | 2018-04-28 01:44:26 -0400 |
commit | d02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch) | |
tree | 54f95eef73aee6bacbfc736fffc631be2605ed53 /pkg/tcpip/transport/tcp/testing | |
parent | f70210e742919f40aa2f0934a22f1c9ba6dada62 (diff) |
Check in gVisor.
PiperOrigin-RevId: 194583126
Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'pkg/tcpip/transport/tcp/testing')
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/BUILD | 27 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 900 |
2 files changed, 927 insertions, 0 deletions
diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD new file mode 100644 index 000000000..40850c3e7 --- /dev/null +++ b/pkg/tcpip/transport/tcp/testing/context/BUILD @@ -0,0 +1,27 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "context", + testonly = 1, + srcs = ["context.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context", + visibility = [ + "//:sandbox", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/seqnum", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go new file mode 100644 index 000000000..6a402d150 --- /dev/null +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -0,0 +1,900 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package context provides a test context for use in tcp tests. It also +// provides helper methods to assert/check certain behaviours. +package context + +import ( + "bytes" + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/checker" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +const ( + // StackAddr is the IPv4 address assigned to the stack. + StackAddr = "\x0a\x00\x00\x01" + + // StackPort is used as the listening port in tests for passive + // connects. + StackPort = 1234 + + // TestAddr is the source address for packets sent to the stack via the + // link layer endpoint. + TestAddr = "\x0a\x00\x00\x02" + + // TestPort is the TCP port used for packets sent to the stack + // via the link layer endpoint. + TestPort = 4096 + + // StackV6Addr is the IPv6 address assigned to the stack. + StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + + // TestV6Addr is the source address for packets sent to the stack via + // the link layer endpoint. + TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + // StackV4MappedAddr is StackAddr as a mapped v6 address. + StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr + + // TestV4MappedAddr is TestAddr as a mapped v6 address. + TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr + + // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0. + V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + + // testInitialSequenceNumber is the initial sequence number sent in packets that + // are sent in response to a SYN or in the initial SYN sent to the stack. + testInitialSequenceNumber = 789 +) + +// defaultWindowScale value specified here depends on the tcp.DefaultBufferSize +// constant defined in the tcp/endpoint.go because the tcp.DefaultBufferSize is +// used in tcp.newHandshake to determine the window scale to use when sending a +// SYN/SYN-ACK. +var defaultWindowScale = tcp.FindWndScale(tcp.DefaultBufferSize) + +// Headers is used to represent the TCP header fields when building a +// new packet. +type Headers struct { + // SrcPort holds the src port value to be used in the packet. + SrcPort uint16 + + // DstPort holds the destination port value to be used in the packet. + DstPort uint16 + + // SeqNum is the value of the sequence number field in the TCP header. + SeqNum seqnum.Value + + // AckNum represents the acknowledgement number field in the TCP header. + AckNum seqnum.Value + + // Flags are the TCP flags in the TCP header. + Flags int + + // RcvWnd is the window to be advertised in the ReceiveWindow field of + // the TCP header. + RcvWnd seqnum.Size + + // TCPOpts holds the options to be sent in the option field of the TCP + // header. + TCPOpts []byte +} + +// Context provides an initialized Network stack and a link layer endpoint +// for use in TCP tests. +type Context struct { + t *testing.T + linkEP *channel.Endpoint + s *stack.Stack + + // IRS holds the initial sequence number in the SYN sent by endpoint in + // case of an active connect or the sequence number sent by the endpoint + // in the SYN-ACK sent in response to a SYN when listening in passive + // mode. + IRS seqnum.Value + + // Port holds the port bound by EP below in case of an active connect or + // the listening port number in case of a passive connect. + Port uint16 + + // EP is the test endpoint in the stack owned by this context. This endpoint + // is used in various tests to either initiate an active connect or is used + // as a passive listening endpoint to accept inbound connections. + EP tcpip.Endpoint + + // Wq is the wait queue associated with EP and is used to block for events + // on EP. + WQ waiter.Queue + + // TimeStampEnabled is true if ep is connected with the timestamp option + // enabled. + TimeStampEnabled bool +} + +// New allocates and initializes a test context containing a new +// stack and a link-layer endpoint. +func New(t *testing.T, mtu uint32) *Context { + s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}) + + // Allow minimum send/receive buffer sizes to be 1 during tests. + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %v", err) + } + + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %v", err) + } + + // Some of the congestion control tests send up to 640 packets, we so + // set the channel size to 1000. + id, linkEP := channel.New(1000, mtu, "") + if testing.Verbose() { + id = sniffer.New(id) + } + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: "\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }, + { + Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }, + }) + + return &Context{ + t: t, + s: s, + linkEP: linkEP, + } +} + +// Cleanup closes the context endpoint if required. +func (c *Context) Cleanup() { + if c.EP != nil { + c.EP.Close() + } +} + +// Stack returns a reference to the stack in the Context. +func (c *Context) Stack() *stack.Stack { + return c.s +} + +// CheckNoPacketTimeout verifies that no packet is received during the time +// specified by wait. +func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { + select { + case <-c.linkEP.C: + c.t.Fatalf(errMsg) + + case <-time.After(wait): + } +} + +// CheckNoPacket verifies that no packet is received for 1 second. +func (c *Context) CheckNoPacket(errMsg string) { + c.CheckNoPacketTimeout(errMsg, 1*time.Second) +} + +// GetPacket reads a packet from the link layer endpoint and verifies +// that it is an IPv4 packet with the expected source and destination +// addresses. It will fail with an error if no packet is received for +// 2 seconds. +func (c *Context) GetPacket() []byte { + select { + case p := <-c.linkEP.C: + if p.Proto != ipv4.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + } + b := make([]byte, len(p.Header)+len(p.Payload)) + copy(b, p.Header) + copy(b[len(p.Header):], p.Payload) + + checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) + return b + + case <-time.After(2 * time.Second): + c.t.Fatalf("Packet wasn't written out") + } + + return nil +} + +// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. +func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { + // Allocate a buffer data and headers. + buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(p1) + len(p2)) + if len(buf) > maxTotalSize { + buf = buf[:maxTotalSize] + } + + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: TestAddr, + DstAddr: StackAddr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + icmp := header.ICMPv4(buf[header.IPv4MinimumSize:]) + icmp.SetType(typ) + icmp.SetCode(code) + + copy(icmp[header.ICMPv4MinimumSize:], p1) + copy(icmp[header.ICMPv4MinimumSize+len(p1):], p2) + + // Inject packet. + var views [1]buffer.View + vv := buf.ToVectorisedView(views) + c.linkEP.Inject(ipv4.ProtocolNumber, &vv) +} + +// SendPacket builds and sends a TCP segment(with the provided payload & TCP +// headers) in an IPv4 packet via the link layer endpoint. +func (c *Context) SendPacket(payload []byte, h *Headers) { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts) + + // Initialize the IP header. + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(tcp.ProtocolNumber), + SrcAddr: TestAddr, + DstAddr: StackAddr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Initialize the TCP header. + t := header.TCP(buf[header.IPv4MinimumSize:]) + t.Encode(&header.TCPFields{ + SrcPort: h.SrcPort, + DstPort: h.DstPort, + SeqNum: uint32(h.SeqNum), + AckNum: uint32(h.AckNum), + DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)), + Flags: uint8(h.Flags), + WindowSize: uint16(h.RcvWnd), + }) + + // Calculate the TCP pseudo-header checksum. + xsum := header.Checksum([]byte(TestAddr), 0) + xsum = header.Checksum([]byte(StackAddr), xsum) + xsum = header.Checksum([]byte{0, uint8(tcp.ProtocolNumber)}, xsum) + + // Calculate the TCP checksum and set it. + length := uint16(header.TCPMinimumSize + len(h.TCPOpts) + len(payload)) + xsum = header.Checksum(payload, xsum) + t.SetChecksum(^t.CalculateChecksum(xsum, length)) + + // Inject packet. + var views [1]buffer.View + vv := buf.ToVectorisedView(views) + c.linkEP.Inject(ipv4.ProtocolNumber, &vv) +} + +// SendAck sends an ACK packet. +func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) { + c.SendPacket(nil, &Headers{ + SrcPort: TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(testInitialSequenceNumber).Add(1), + AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), + RcvWnd: 30000, + }) +} + +// ReceiveAndCheckPacket reads a packet from the link layer endpoint and +// verifies that the packet packet payload of packet matches the slice +// of data indicated by offset & size. +func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { + b := c.GetPacket() + checker.IPv4(c.t, b, + checker.PayloadLen(size+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(TestPort), + checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), + checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + pdata := data[offset:][:size] + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 { + c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) + } +} + +// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only +// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6 +// only endpoint instead of a default dual stack socket. +func (c *Context) CreateV6Endpoint(v6only bool) { + var err *tcpip.Error + c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + var v tcpip.V6OnlyOption + if v6only { + v = 1 + } + if err := c.EP.SetSockOpt(v); err != nil { + c.t.Fatalf("SetSockOpt failed failed: %v", err) + } +} + +// GetV6Packet reads a single packet from the link layer endpoint of the context +// and asserts that it is an IPv6 Packet with the expected src/dest addresses. +func (c *Context) GetV6Packet() []byte { + select { + case p := <-c.linkEP.C: + if p.Proto != ipv6.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) + } + b := make([]byte, len(p.Header)+len(p.Payload)) + copy(b, p.Header) + copy(b[len(p.Header):], p.Payload) + + checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) + return b + + case <-time.After(2 * time.Second): + c.t.Fatalf("Packet wasn't written out") + } + + return nil +} + +// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of +// the context. +func (c *Context) SendV6Packet(payload []byte, h *Headers) { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(header.TCPMinimumSize + len(payload)), + NextHeader: uint8(tcp.ProtocolNumber), + HopLimit: 65, + SrcAddr: TestV6Addr, + DstAddr: StackV6Addr, + }) + + // Initialize the TCP header. + t := header.TCP(buf[header.IPv6MinimumSize:]) + t.Encode(&header.TCPFields{ + SrcPort: h.SrcPort, + DstPort: h.DstPort, + SeqNum: uint32(h.SeqNum), + AckNum: uint32(h.AckNum), + DataOffset: header.TCPMinimumSize, + Flags: uint8(h.Flags), + WindowSize: uint16(h.RcvWnd), + }) + + // Calculate the TCP pseudo-header checksum. + xsum := header.Checksum([]byte(TestV6Addr), 0) + xsum = header.Checksum([]byte(StackV6Addr), xsum) + xsum = header.Checksum([]byte{0, uint8(tcp.ProtocolNumber)}, xsum) + + // Calculate the TCP checksum and set it. + length := uint16(header.TCPMinimumSize + len(payload)) + xsum = header.Checksum(payload, xsum) + t.SetChecksum(^t.CalculateChecksum(xsum, length)) + + // Inject packet. + var views [1]buffer.View + vv := buf.ToVectorisedView(views) + c.linkEP.Inject(ipv6.ProtocolNumber, &vv) +} + +// CreateConnected creates a connected TCP endpoint. +func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) { + c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) +} + +// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends +// the specified option bytes as the Option field in the initial SYN packet. +// +// It also sets the receive buffer for the endpoint to the specified +// value in epRcvBuf. +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) { + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + if epRcvBuf != nil { + if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { + c.t.Fatalf("SetSockOpt failed failed: %v", err) + } + } + + // Start connection attempt. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + err = c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}) + if err != tcpip.ErrConnectStarted { + c.t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(c.t, b, + checker.TCP( + checker.DstPort(TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + tcp := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + c.SendPacket(nil, &Headers{ + SrcPort: tcp.DestinationPort(), + DstPort: tcp.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: options, + }) + + // Receive ACK packet. + checker.IPv4(c.t, c.GetPacket(), + checker.TCP( + checker.DstPort(TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + ), + ) + + // Wait for connection to be established. + select { + case <-notifyCh: + err = c.EP.GetSockOpt(tcpip.ErrorOption{}) + if err != nil { + c.t.Fatalf("Unexpected error when connecting: %v", err) + } + case <-time.After(1 * time.Second): + c.t.Fatalf("Timed out waiting for connection") + } + + c.Port = tcp.SourcePort() +} + +// RawEndpoint is just a small wrapper around a TCP endpoint's state to make +// sending data and ACK packets easy while being able to manipulate the sequence +// numbers and timestamp values as needed. +type RawEndpoint struct { + C *Context + SrcPort uint16 + DstPort uint16 + Flags int + NextSeqNum seqnum.Value + AckNum seqnum.Value + WndSize seqnum.Size + RecentTS uint32 // Stores the latest timestamp to echo back. + TSVal uint32 // TSVal stores the last timestamp sent by this endpoint. + + // SackPermitted is true if SACKPermitted option was negotiated for this endpoint. + SACKPermitted bool +} + +// SendPacketWithTS embeds the provided tsVal in the Timestamp option +// for the packet to be sent out. +func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) { + r.TSVal = tsVal + tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} + header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:]) + r.SendPacket(payload, tsOpt[:]) +} + +// SendPacket is a small wrapper function to build and send packets. +func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) { + packetHeaders := &Headers{ + SrcPort: r.SrcPort, + DstPort: r.DstPort, + Flags: r.Flags, + SeqNum: r.NextSeqNum, + AckNum: r.AckNum, + RcvWnd: r.WndSize, + TCPOpts: opts, + } + r.C.SendPacket(payload, packetHeaders) + r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload))) +} + +// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided +// tsVal. +func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { + // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4] + ackPacket := r.C.GetPacket() + checker.IPv4(r.C.t, ackPacket, + checker.TCP( + checker.DstPort(r.SrcPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(r.AckNum)), + checker.AckNum(uint32(r.NextSeqNum)), + checker.TCPTimestampChecker(true, 0, tsVal), + ), + ) + // Store the parsed TSVal from the ack as recentTS. + tcpSeg := header.TCP(header.IPv4(ackPacket).Payload()) + opts := tcpSeg.ParsedOptions() + r.RecentTS = opts.TSVal +} + +// VerifyACKNoSACK verifies that the ACK does not contain a SACK block. +func (r *RawEndpoint) VerifyACKNoSACK() { + r.VerifyACKHasSACK(nil) +} + +// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks. +func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { + // Read ACK and verify that the TCP options in the segment do + // not contain a SACK block. + ackPacket := r.C.GetPacket() + checker.IPv4(r.C.t, ackPacket, + checker.TCP( + checker.DstPort(r.SrcPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(r.AckNum)), + checker.AckNum(uint32(r.NextSeqNum)), + checker.TCPSACKBlockChecker(sackBlocks), + ), + ) +} + +// CreateConnectedWithOptions creates and connects c.ep with the specified TCP +// options enabled and returns a RawEndpoint which represents the other end of +// the connection. +// +// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK +// does not carry an option that was not requested. +func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint { + var err *tcpip.Error + c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) + } + + // Start connection attempt. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort} + err = c.EP.Connect(testFullAddr) + if err != tcpip.ErrConnectStarted { + c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err) + } + // Receive SYN packet. + b := c.GetPacket() + // Validate that the syn has the timestamp option and a valid + // TS value. + checker.IPv4(c.t, b, + checker.TCP( + checker.DstPort(TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{ + MSS: uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize), + TS: true, + WS: defaultWindowScale, + SACKPermitted: c.SACKEnabled(), + }), + ), + ) + tcpSeg := header.TCP(header.IPv4(b).Payload()) + synOptions := header.ParseSynOptions(tcpSeg.Options(), false) + + // Build options w/ tsVal to be sent in the SYN-ACK. + synAckOptions := make([]byte, 40) + offset := 0 + if wantOptions.TS { + offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:]) + } + if wantOptions.SACKPermitted { + offset += header.EncodeSACKPermittedOption(synAckOptions[offset:]) + } + + offset += header.AddTCPOptionPadding(synAckOptions, offset) + + // Build SYN-ACK. + c.IRS = seqnum.Value(tcpSeg.SequenceNumber()) + iss := seqnum.Value(testInitialSequenceNumber) + c.SendPacket(nil, &Headers{ + SrcPort: tcpSeg.DestinationPort(), + DstPort: tcpSeg.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + TCPOpts: synAckOptions[:offset], + }) + + // Read ACK. + ackPacket := c.GetPacket() + + // Verify TCP header fields. + tcpCheckers := []checker.TransportChecker{ + checker.DstPort(TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS) + 1), + checker.AckNum(uint32(iss) + 1), + } + + // Verify that tsEcr of ACK packet is wantOptions.TSVal if the + // timestamp option was enabled, if not then we verify that + // there is no timestamp in the ACK packet. + if wantOptions.TS { + tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal)) + } else { + tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) + } + + checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...)) + + ackSeg := header.TCP(header.IPv4(ackPacket).Payload()) + ackOptions := ackSeg.ParsedOptions() + + // Wait for connection to be established. + select { + case <-notifyCh: + err = c.EP.GetSockOpt(tcpip.ErrorOption{}) + if err != nil { + c.t.Fatalf("Unexpected error when connecting: %v", err) + } + case <-time.After(1 * time.Second): + c.t.Fatalf("Timed out waiting for connection") + } + + // Store the source port in use by the endpoint. + c.Port = tcpSeg.SourcePort() + + // Mark in context that timestamp option is enabled for this endpoint. + c.TimeStampEnabled = true + + return &RawEndpoint{ + C: c, + SrcPort: tcpSeg.DestinationPort(), + DstPort: tcpSeg.SourcePort(), + Flags: header.TCPFlagAck | header.TCPFlagPsh, + NextSeqNum: iss + 1, + AckNum: c.IRS.Add(1), + WndSize: 30000, + RecentTS: ackOptions.TSVal, + TSVal: wantOptions.TSVal, + SACKPermitted: wantOptions.SACKPermitted, + } +} + +// AcceptWithOptions initializes a listening endpoint and connects to it with the +// provided options enabled. It also verifies that the SYN-ACK has the expected +// values for the provided options. +// +// The function returns a RawEndpoint representing the other end of the accepted +// endpoint. +func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { + // Create EP and start listening. + wq := &waiter.Queue{} + ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Port: StackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + c.t.Fatalf("Listen failed: %v", err) + } + + rep := c.PassiveConnectWithOptions(100, wndScale, synOptions) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + c.t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + c.t.Fatalf("Timed out waiting for accept") + } + } + return rep +} + +// PassiveConnect just disables WindowScaling and delegates the call to +// PassiveConnectWithOptions. +func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) { + synOptions.WS = -1 + c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions) +} + +// PassiveConnectWithOptions initiates a new connection (with the specified TCP +// options enabled) to the port on which the Context.ep is listening for new +// connections. It also validates that the SYN-ACK has the expected values for +// the enabled options. +// +// NOTE: MSS is not a negotiated option and it can be asymmetric +// in each direction. This function uses the maxPayload to set the MSS to be +// sent to the peer on a connect and validates that the MSS in the SYN-ACK +// response is equal to the MTU - (tcphdr len + iphdr len). +// +// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the +// value of the window scaling option to be sent in the SYN. If synOptions.WS > +// 0 then we send the WindowScale option. +func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { + opts := make([]byte, 40) + offset := 0 + offset += header.EncodeMSSOption(uint32(maxPayload), opts) + + if synOptions.WS >= 0 { + offset += header.EncodeWSOption(3, opts[offset:]) + } + if synOptions.TS { + offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:]) + } + + if synOptions.SACKPermitted { + offset += header.EncodeSACKPermittedOption(opts[offset:]) + } + + paddingToAdd := 4 - offset%4 + // Now add any padding bytes that might be required to quad align the + // options. + for i := offset; i < offset+paddingToAdd; i++ { + opts[i] = header.TCPOptionNOP + } + offset += paddingToAdd + + // Send a SYN request. + iss := seqnum.Value(testInitialSequenceNumber) + c.SendPacket(nil, &Headers{ + SrcPort: TestPort, + DstPort: StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + TCPOpts: opts[:offset], + }) + + // Receive the SYN-ACK reply. Make sure MSS and other expected options + // are present. + b := c.GetPacket() + tcp := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(StackPort), + checker.DstPort(TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.AckNum(uint32(iss) + 1), + checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}), + } + + // If TS option was enabled in the original SYN then add a checker to + // validate the Timestamp option in the SYN-ACK. + if synOptions.TS { + tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal)) + } else { + tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) + } + + checker.IPv4(c.t, b, checker.TCP(tcpCheckers...)) + rcvWnd := seqnum.Size(30000) + ackHeaders := &Headers{ + SrcPort: TestPort, + DstPort: StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + RcvWnd: rcvWnd, + } + + // If WS was expected to be in effect then scale the advertised window + // correspondingly. + if synOptions.WS > 0 { + ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS) + } + + parsedOpts := tcp.ParsedOptions() + if synOptions.TS { + // Echo the tsVal back to the peer in the tsEcr field of the + // timestamp option. + // Increment TSVal by 1 from the value sent in the SYN and echo + // the TSVal in the SYN-ACK in the TSEcr field. + opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} + header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:]) + ackHeaders.TCPOpts = opts[:] + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + c.Port = StackPort + + return &RawEndpoint{ + C: c, + SrcPort: TestPort, + DstPort: StackPort, + Flags: header.TCPFlagPsh | header.TCPFlagAck, + NextSeqNum: iss + 1, + AckNum: c.IRS + 1, + WndSize: rcvWnd, + SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(), + RecentTS: parsedOpts.TSVal, + TSVal: synOptions.TSVal + 1, + } +} + +// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true +// for the Stack in the context. +func (c *Context) SACKEnabled() bool { + var v tcp.SACKEnabled + if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { + // Stack doesn't support SACK. So just return. + return false + } + return bool(v) +} |