diff options
Diffstat (limited to 'pkg/tcpip/transport')
22 files changed, 1457 insertions, 1193 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 346ca4bda..41eb0ca44 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -74,6 +74,8 @@ type endpoint struct { route stack.Route `state:"manual"` ttl uint8 stats tcpip.TransportEndpointStats `state:"nosave"` + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -344,9 +346,14 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch opt.(type) { + switch v := opt.(type) { case *tcpip.SocketDetachFilterOption: return nil + + case *tcpip.LingerOption: + e.mu.Lock() + e.linger = *v + e.mu.Unlock() } return nil } @@ -415,8 +422,17 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + e.mu.Lock() + *o = e.linger + e.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { @@ -430,6 +446,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi pkt.Owner = owner icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber copy(icmpv4, data) // Set the ident to the user-specified port. Sequence number should // already be set by the user. @@ -462,6 +479,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err }) icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber copy(icmpv6, data) // Set the ident. Sequence number is provided by the user. icmpv6.SetIdent(ident) @@ -597,7 +615,7 @@ func (*endpoint) Listen(int) *tcpip.Error { } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 74ef6541e..87d510f96 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -13,12 +13,7 @@ // limitations under the License. // Package icmp contains the implementation of the ICMP and IPv6-ICMP transport -// protocols for use in ping. To use it in the networking stack, this package -// must be added to the project, and activated on the stack by passing -// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport -// protocols when calling stack.New(). Then endpoints can be created by passing -// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number -// when calling Stack.NewEndpoint(). +// protocols for use in ping. package icmp import ( @@ -42,6 +37,8 @@ const ( // protocol implements stack.TransportProtocol. type protocol struct { + stack *stack.Stack + number tcpip.TransportProtocolNumber } @@ -62,20 +59,20 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber { // NewEndpoint creates a new icmp endpoint. It implements // stack.TransportProtocol.NewEndpoint. -func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return newEndpoint(stack, netProto, p.number, waiterQueue) + return newEndpoint(p.stack, netProto, p.number, waiterQueue) } // NewRawEndpoint creates a new raw icmp endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return raw.NewEndpoint(stack, netProto, p.number, waiterQueue) + return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue) } // MinimumPacketSize returns the minimum valid icmp packet size. @@ -104,17 +101,17 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { - return true +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { + return stack.UnknownDestinationPacketHandled } // SetOption implements stack.TransportProtocol.SetOption. -func (*protocol) SetOption(option interface{}) *tcpip.Error { +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // Option implements stack.TransportProtocol.Option. -func (*protocol) Option(option interface{}) *tcpip.Error { +func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } @@ -135,11 +132,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol4 returns an ICMPv4 transport protocol. -func NewProtocol4() stack.TransportProtocol { - return &protocol{ProtocolNumber4} +func NewProtocol4(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s, number: ProtocolNumber4} } // NewProtocol6 returns an ICMPv6 transport protocol. -func NewProtocol6() stack.TransportProtocol { - return &protocol{ProtocolNumber6} +func NewProtocol6(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s, number: ProtocolNumber6} } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 81093e9ca..072601d2d 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -83,6 +83,8 @@ type endpoint struct { stats tcpip.TransportEndpointStats `state:"nosave"` bound bool boundNIC tcpip.NICID + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` @@ -192,13 +194,13 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes return ep.ReadPacket(addr, nil) } -func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // TODO(gvisor.dev/issue/173): Implement. return 0, nil, tcpip.ErrInvalidOptionValue } // Peek implements tcpip.Endpoint.Peek. -func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -210,25 +212,25 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be // connected, and this function always returnes tcpip.ErrNotSupported. -func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrNotSupported } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used // with Shutdown, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { return tcpip.ErrNotSupported } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with // Listen, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Listen(backlog int) *tcpip.Error { +func (*endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with // Accept, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -267,12 +269,12 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -298,10 +300,16 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch opt.(type) { + switch v := opt.(type) { case *tcpip.SocketDetachFilterOption: return nil + case *tcpip.LingerOption: + ep.mu.Lock() + ep.linger = *v + ep.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -366,12 +374,21 @@ func (ep *endpoint) LastError() *tcpip.Error { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrNotSupported +func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + ep.mu.Lock() + *o = ep.linger + ep.mu.Unlock() + return nil + + default: + return tcpip.ErrNotSupported + } } // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { +func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return false, tcpip.ErrNotSupported } @@ -508,7 +525,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, } // State implements socket.Socket.State. -func (ep *endpoint) State() uint32 { +func (*endpoint) State() uint32 { return 0 } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 71feeb748..e37c00523 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -84,6 +84,8 @@ type endpoint struct { // Connect(), and is valid only when conneted is true. route stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -446,12 +448,12 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Listen implements tcpip.Endpoint.Listen. -func (e *endpoint) Listen(backlog int) *tcpip.Error { +func (*endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. -func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -482,12 +484,12 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -511,10 +513,16 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch opt.(type) { + switch v := opt.(type) { case *tcpip.SocketDetachFilterOption: return nil + case *tcpip.LingerOption: + e.mu.Lock() + e.linger = *v + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -577,8 +585,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + e.mu.Lock() + *o = e.linger + e.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 234fb95ce..518449602 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -69,6 +69,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", @@ -93,6 +94,7 @@ go_test( shard_count = 10, deps = [ ":tcp", + "//pkg/rand", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 72df5c2a1..6891fd245 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -522,7 +522,7 @@ func (h *handshake) execute() *tcpip.Error { s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) defer s.Done() - var sackEnabled SACKEnabled + var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled // status then just default to switching off SACK negotiation. @@ -747,6 +747,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { optLen := len(tf.opts) tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen)) + pkt.TransportProtocolNumber = header.TCPProtocolNumber tcp.Encode(&header.TCPFields{ SrcPort: tf.id.LocalPort, DstPort: tf.id.RemotePort, @@ -897,7 +898,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) @@ -1002,9 +1003,8 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // (indicated by a negative send window scale). e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - rcvBufSize := seqnum.Size(e.receiveBufferSize()) e.rcvListMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) // Bootstrap the auto tuning algorithm. Starting at zero will // result in a really large receive window after the first auto // tuning adjustment. @@ -1135,12 +1135,11 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { } cont, err := e.handleSegment(s) + s.decRef() if err != nil { - s.decRef() return err } if !cont { - s.decRef() return nil } } @@ -1220,6 +1219,12 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { return true, nil } + // Increase counter if after processing the segment we would potentially + // advertise a zero window. + if crossed, above := e.windowCrossedACKThresholdLocked(-s.segMemSize()); crossed && !above { + e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() + } + // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. @@ -1232,7 +1237,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // or a notification from the protocolMainLoop (caller goroutine). // This means that with this return, the segment dequeue below can // never occur on a closed endpoint. - s.decRef() return false, nil } @@ -1424,10 +1428,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.rcv.nonZeroWindow() } - if n¬ifyReceiveWindowChanged != 0 { - e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize()) - } - if n¬ifyMTUChanged != 0 { e.sndBufMu.Lock() count := e.packetTooBigCount diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 6074cc24e..560b4904c 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -78,8 +78,8 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network ackCheckers := append(checkers, checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), )) checker.IPv4(t, c.GetPacket(), ackCheckers...) @@ -185,8 +185,8 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network ackCheckers := append(checkers, checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), )) checker.IPv6(t, c.GetV6Packet(), ackCheckers...) @@ -283,7 +283,7 @@ func TestV4RefuseOnV6Only(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) } @@ -319,7 +319,7 @@ func TestV6RefuseOnBoundToV4Mapped(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) } @@ -352,7 +352,7 @@ func testV4Accept(t *testing.T, c *context.Context) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) @@ -371,12 +371,12 @@ func testV4Accept(t *testing.T, c *context.Context) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept() + nep, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %v", err) } @@ -492,7 +492,7 @@ func TestV6AcceptOnV6(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) @@ -510,13 +510,13 @@ func TestV6AcceptOnV6(t *testing.T) { we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - - nep, _, err := c.EP.Accept() + var addr tcpip.FullAddress + nep, _, err := c.EP.Accept(&addr) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(&addr) if err != nil { t.Fatalf("Accept failed: %v", err) } @@ -526,20 +526,14 @@ func TestV6AcceptOnV6(t *testing.T) { } } + if addr.Addr != context.TestV6Addr { + t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr) + } + // Make sure we can still query the v6 only status of the new endpoint, // that is, that it is in fact a v6 socket. if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil { - t.Fatalf("GetSockOpt failed failed: %v", err) - } - - // Check the peer address. - addr, err := nep.GetRemoteAddress() - if err != nil { - t.Fatalf("GetRemoteAddress failed failed: %v", err) - } - - if addr.Addr != context.TestV6Addr { - t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr) + t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err) } } @@ -566,8 +560,9 @@ func TestV4AcceptOnV4(t *testing.T) { func testV4ListenClose(t *testing.T, c *context.Context) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } const n = uint16(32) @@ -610,12 +605,12 @@ func testV4ListenClose(t *testing.T, c *context.Context) { we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept() + nep, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %v", err) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index c5d9eba5d..ae817091a 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -63,6 +63,17 @@ const ( StateClosing ) +const ( + // rcvAdvWndScale is used to split the available socket buffer into + // application buffer and the window to be advertised to the peer. This is + // currently hard coded to split the available space equally. + rcvAdvWndScale = 1 + + // SegOverheadFactor is used to multiply the value provided by the + // user on a SetSockOpt for setting the socket send/receive buffer sizes. + SegOverheadFactor = 2 +) + // connected returns true when s is one of the states representing an // endpoint connected to a peer. func (s EndpointState) connected() bool { @@ -149,7 +160,6 @@ func (s EndpointState) String() string { // Reasons for notifying the protocol goroutine. const ( notifyNonZeroReceiveWindow = 1 << iota - notifyReceiveWindowChanged notifyClose notifyMTUChanged notifyDrain @@ -384,13 +394,26 @@ type endpoint struct { // to indicate to users that no more data is coming. // // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex `state:"nosave"` - rcvList segmentList `state:"wait"` - rcvClosed bool - rcvBufSize int + rcvListMu sync.Mutex `state:"nosave"` + rcvList segmentList `state:"wait"` + rcvClosed bool + // rcvBufSize is the total size of the receive buffer. + rcvBufSize int + // rcvBufUsed is the actual number of payload bytes held in the receive buffer + // not counting any overheads of the segments itself. NOTE: This will always + // be strictly <= rcvMemUsed below. rcvBufUsed int rcvAutoParams rcvBufAutoTuneParams + // rcvMemUsed tracks the total amount of memory in use by received segments + // held in rcvList, pendingRcvdSegments and the segment queue. This is used to + // compute the window and the actual available buffer space. This is distinct + // from rcvBufUsed above which is the actual number of payload bytes held in + // the buffer not including any segment overheads. + // + // rcvMemUsed must be accessed atomically. + rcvMemUsed int32 + // mu protects all endpoint fields unless documented otherwise. mu must // be acquired before interacting with the endpoint fields. mu sync.Mutex `state:"nosave"` @@ -852,12 +875,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue maxSynRetries: DefaultSynRetries, } - var ss SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { e.sndBufSize = ss.Default } - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { e.rcvBufSize = rs.Default } @@ -867,12 +890,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.cc = cs } - var mrb tcpip.ModerateReceiveBufferOption + var mrb tcpip.TCPModerateReceiveBufferOption if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { e.rcvAutoParams.disabled = !bool(mrb) } - var de DelayEnabled + var de tcpip.TCPDelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { e.SetSockOptBool(tcpip.DelayOption, true) } @@ -891,7 +914,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.probe = p } - e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.segmentQueue.ep = e e.tsOffset = timeStampOffset() e.acceptCond = sync.NewCond(&e.acceptMu) @@ -904,7 +927,12 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { result := waiter.EventMask(0) switch e.EndpointState() { - case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: + case StateInitial, StateBound: + // This prevents blocking of new sockets which are not + // connected when SO_LINGER is set. + result |= waiter.EventHUp + + case StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. case StateClose, StateError, StateTimeWait: @@ -1075,6 +1103,8 @@ func (e *endpoint) closeNoShutdownLocked() { e.notifyProtocolGoroutine(notifyClose) } else { e.transitionToStateCloseLocked() + // Notify that the endpoint is closed. + e.waiterQueue.Notify(waiter.EventHUp) } } @@ -1129,10 +1159,16 @@ func (e *endpoint) cleanupLocked() { tcpip.DeleteDanglingEndpoint(e) } +// wndFromSpace returns the window that we can advertise based on the available +// receive buffer space. +func wndFromSpace(space int) int { + return space / (1 << rcvAdvWndScale) +} + // initialReceiveWindow returns the initial receive window to advertise in the // SYN/SYN-ACK. func (e *endpoint) initialReceiveWindow() int { - rcvWnd := e.receiveBufferAvailable() + rcvWnd := wndFromSpace(e.receiveBufferAvailable()) if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } @@ -1209,14 +1245,12 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // reject valid data that might already be in flight as the // acceptable window will shrink. if rcvWnd > e.rcvBufSize { - availBefore := e.receiveBufferAvailableLocked() + availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvBufSize = rcvWnd - availAfter := e.receiveBufferAvailableLocked() - mask := uint32(notifyReceiveWindowChanged) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - mask |= notifyNonZeroReceiveWindow + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } - e.notifyProtocolGoroutine(mask) } // We only update prevCopied when we grow the buffer because in cases @@ -1293,18 +1327,22 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { v := views[s.viewToDeliver] s.viewToDeliver++ + var delta int if s.viewToDeliver >= len(views) { e.rcvList.Remove(s) + // We only free up receive buffer space when the segment is released as the + // segment is still holding on to the views even though some views have been + // read out to the user. + delta = s.segMemSize() s.decRef() } e.rcvBufUsed -= len(v) - // If the window was small before this read and if the read freed up // enough buffer space, to either fit an aMSS or half a receive buffer // (whichever smaller), then notify the protocol goroutine to send a // window update. - if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -1317,14 +1355,17 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // indicating the reason why it's not writable. // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { - // The endpoint cannot be written to if it's not connected. - if !e.EndpointState().connected() { - switch e.EndpointState() { - case StateError: - return 0, e.HardError - default: - return 0, tcpip.ErrClosedForSend - } + switch s := e.EndpointState(); { + case s == StateError: + return 0, e.HardError + case !s.connecting() && !s.connected(): + return 0, tcpip.ErrClosedForSend + case s.connecting(): + // As per RFC793, page 56, a send request arriving when in connecting + // state, can be queued to be completed after the state becomes + // connected. Return an error code for the caller of endpoint Write to + // try again, until the connection handshake is complete. + return 0, tcpip.ErrWouldBlock } // Check if the connection has already been closed for sends. @@ -1478,11 +1519,11 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro } // windowCrossedACKThresholdLocked checks if the receive window to be announced -// now would be under aMSS or under half receive buffer, whichever smaller. This -// is useful as a receive side silly window syndrome prevention mechanism. If -// window grows to reasonable value, we should send ACK to the sender to inform -// the rx space is now large. We also want ensure a series of small read()'s -// won't trigger a flood of spurious tiny ACK's. +// would be under aMSS or under the window derived from half receive buffer, +// whichever smaller. This is useful as a receive side silly window syndrome +// prevention mechanism. If window grows to reasonable value, we should send ACK +// to the sender to inform the rx space is now large. We also want ensure a +// series of small read()'s won't trigger a flood of spurious tiny ACK's. // // For large receive buffers, the threshold is aMSS - once reader reads more // than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of @@ -1493,17 +1534,18 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // // Precondition: e.mu and e.rcvListMu must be held. func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { - newAvail := e.receiveBufferAvailableLocked() + newAvail := wndFromSpace(e.receiveBufferAvailableLocked()) oldAvail := newAvail - deltaBefore if oldAvail < 0 { oldAvail = 0 } - threshold := int(e.amss) - if threshold > e.rcvBufSize/2 { - threshold = e.rcvBufSize / 2 + // rcvBufFraction is the inverse of the fraction of receive buffer size that + // is used to decide if the available buffer space is now above it. + const rcvBufFraction = 2 + if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold { + threshold = wndThreshold } - switch { case oldAvail < threshold && newAvail >= threshold: return true, true @@ -1632,18 +1674,24 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs ReceiveBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + var rs tcpip.TCPReceiveBufferSizeRangeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err)) + } + + if v > rs.Max { + v = rs.Max + } + + if v < math.MaxInt32/SegOverheadFactor { + v *= SegOverheadFactor if v < rs.Min { v = rs.Min } - if v > rs.Max { - v = rs.Max - } + } else { + v = math.MaxInt32 } - mask := uint32(notifyReceiveWindowChanged) - e.LockUser() e.rcvListMu.Lock() @@ -1657,14 +1705,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { v = 1 << scale } - // Make sure 2*size doesn't overflow. - if v > math.MaxInt32/2 { - v = math.MaxInt32 / 2 - } - - availBefore := e.receiveBufferAvailableLocked() + availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvBufSize = v - availAfter := e.receiveBufferAvailableLocked() + availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvAutoParams.disabled = true @@ -1672,24 +1715,31 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // syndrome prevetion, when our available space grows above aMSS // or half receive buffer, whichever smaller. if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - mask |= notifyNonZeroReceiveWindow + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } e.rcvListMu.Unlock() e.UnlockUser() - e.notifyProtocolGoroutine(mask) case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss SendBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + var ss tcpip.TCPSendBufferSizeRangeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err)) + } + + if v > ss.Max { + v = ss.Max + } + + if v < math.MaxInt32/SegOverheadFactor { + v *= SegOverheadFactor if v < ss.Min { v = ss.Min } - if v > ss.Max { - v = ss.Max - } + } else { + v = math.MaxInt32 } e.sndBufMu.Lock() @@ -1722,7 +1772,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrInvalidOptionValue } } - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min/2 { v = rs.Min / 2 @@ -1771,7 +1821,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // Query the available cc algorithms in the stack and // validate that the specified algorithm is actually // supported in the stack. - var avail tcpip.AvailableCongestionControlOption + var avail tcpip.TCPAvailableCongestionControlOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil { return err } @@ -2047,8 +2097,10 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { e.UnlockUser() case *tcpip.OriginalDestinationOption: + e.LockUser() ipt := e.stack.IPTables() - addr, port, err := ipt.OriginalDst(e.ID) + addr, port, err := ipt.OriginalDst(e.ID, e.NetProto) + e.UnlockUser() if err != nil { return err } @@ -2486,7 +2538,9 @@ func (e *endpoint) startAcceptedLoop() { // Accept returns a new endpoint if a peer has established a connection // to an endpoint previously set to listen mode. -func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +// +// addr if not-nil will contain the peer address of the returned endpoint. +func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -2508,6 +2562,9 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { default: return nil, nil, tcpip.ErrWouldBlock } + if peerAddr != nil { + *peerAddr = n.getRemoteAddress() + } return n, n.waiterQueue, nil } @@ -2610,11 +2667,15 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotConnected } + return e.getRemoteAddress(), nil +} + +func (e *endpoint) getRemoteAddress() tcpip.FullAddress { return tcpip.FullAddress{ Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort, NIC: e.boundNICID, - }, nil + } } func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { @@ -2685,13 +2746,8 @@ func (e *endpoint) updateSndBufferUsage(v int) { func (e *endpoint) readyToRead(s *segment) { e.rcvListMu.Lock() if s != nil { + e.rcvBufUsed += s.payloadSize() s.incRef() - e.rcvBufUsed += s.data.Size() - // Increase counter if the receive window falls down below MSS - // or half receive buffer size, whichever smaller. - if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above { - e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() - } e.rcvList.PushBack(s) } else { e.rcvClosed = true @@ -2706,15 +2762,17 @@ func (e *endpoint) readyToRead(s *segment) { func (e *endpoint) receiveBufferAvailableLocked() int { // We may use more bytes than the buffer size when the receive buffer // shrinks. - if e.rcvBufUsed >= e.rcvBufSize { + memUsed := e.receiveMemUsed() + if memUsed >= e.rcvBufSize { return 0 } - return e.rcvBufSize - e.rcvBufUsed + return e.rcvBufSize - memUsed } // receiveBufferAvailable calculates how many bytes are still available in the -// receive buffer. +// receive buffer based on the actual memory used by all segments held in +// receive buffer/pending and segment queue. func (e *endpoint) receiveBufferAvailable() int { e.rcvListMu.Lock() available := e.receiveBufferAvailableLocked() @@ -2722,16 +2780,37 @@ func (e *endpoint) receiveBufferAvailable() int { return available } +// receiveBufferUsed returns the amount of in-use receive buffer. +func (e *endpoint) receiveBufferUsed() int { + e.rcvListMu.Lock() + used := e.rcvBufUsed + e.rcvListMu.Unlock() + return used +} + +// receiveBufferSize returns the current size of the receive buffer. func (e *endpoint) receiveBufferSize() int { e.rcvListMu.Lock() size := e.rcvBufSize e.rcvListMu.Unlock() - return size } +// receiveMemUsed returns the total memory in use by segments held by this +// endpoint. +func (e *endpoint) receiveMemUsed() int { + return int(atomic.LoadInt32(&e.rcvMemUsed)) +} + +// updateReceiveMemUsed adds the provided delta to e.rcvMemUsed. +func (e *endpoint) updateReceiveMemUsed(delta int) { + atomic.AddInt32(&e.rcvMemUsed, int32(delta)) +} + +// maxReceiveBufferSize returns the stack wide maximum receive buffer size for +// an endpoint. func (e *endpoint) maxReceiveBufferSize() int { - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { // As a fallback return the hardcoded max buffer size. return MaxBufferSize @@ -2811,7 +2890,7 @@ func timeStampOffset() uint32 { // if the SYN options indicate that the SACK option was negotiated and the TCP // stack is configured to enable TCP SACK option. func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { - var v SACKEnabled + var v tcpip.TCPSACKEnabled if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return @@ -2880,7 +2959,6 @@ func (e *endpoint) completeState() stack.TCPEndpointState { RcvAcc: e.rcv.rcvAcc, RcvWndScale: e.rcv.rcvWndScale, PendingBufUsed: e.rcv.pendingBufUsed, - PendingBufSize: e.rcv.pendingBufSize, } // Copy sender state. diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 723e47ddc..b25431467 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -44,7 +44,7 @@ func (e *endpoint) drainSegmentLocked() { // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets. - e.segmentQueue.setLimit(0) + e.segmentQueue.freeze() e.mu.Lock() defer e.mu.Unlock() @@ -178,18 +178,18 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.segmentQueue.thaw() epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: - var ss SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) } } - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max { panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max)) diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index c5afa2680..5bce73605 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -12,12 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package tcp contains the implementation of the TCP transport protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing tcp.NewProtocol() as one of the -// transport protocols when calling stack.New(). Then endpoints can be created -// by passing tcp.ProtocolNumber as the transport protocol number when calling -// Stack.NewEndpoint(). +// Package tcp contains the implementation of the TCP transport protocol. package tcp import ( @@ -29,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" @@ -79,50 +75,6 @@ const ( ccCubic = "cubic" ) -// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to -// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. -type SACKEnabled bool - -// Recovery is used by stack.(*Stack).TransportProtocolOption to -// set loss detection algorithm in TCP. -type Recovery int32 - -const ( - // RACKLossDetection indicates RACK is used for loss detection and - // recovery. - RACKLossDetection Recovery = 1 << iota - - // RACKStaticReoWnd indicates the reordering window should not be - // adjusted when DSACK is received. - RACKStaticReoWnd - - // RACKNoDupTh indicates RACK should not consider the classic three - // duplicate acknowledgements rule to mark the segments as lost. This - // is used when reordering is not detected. - RACKNoDupTh -) - -// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to -// enable/disable Nagle's algorithm in TCP. -type DelayEnabled bool - -// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption -// to get/set the default, min and max TCP send buffer sizes. -type SendBufferSizeOption struct { - Min int - Default int - Max int -} - -// ReceiveBufferSizeOption is used by -// stack.(Stack*).TransportProtocolOption to get/set the default, min and max -// TCP receive buffer sizes. -type ReceiveBufferSizeOption struct { - Min int - Default int - Max int -} - // syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The // value is protected by a mutex so that we can increment only when it's // guaranteed not to go above a threshold. @@ -181,12 +133,14 @@ func (s *synRcvdCounter) Threshold() uint64 { } type protocol struct { + stack *stack.Stack + mu sync.RWMutex sackEnabled bool - recovery Recovery + recovery tcpip.TCPRecovery delayEnabled bool - sendBufferSize SendBufferSizeOption - recvBufferSize ReceiveBufferSizeOption + sendBufferSize tcpip.TCPSendBufferSizeRangeOption + recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool @@ -207,14 +161,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new tcp endpoint. -func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newEndpoint(stack, netProto, waiterQueue), nil +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently // unsupported. It implements stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue) +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue) } // MinimumPacketSize returns the minimum valid tcp packet size. @@ -244,21 +198,20 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { s := newSegment(r, id, pkt) defer s.decRef() if !s.parse() || !s.csumValid { - return false + return stack.UnknownDestinationPacketMalformed } - // There's nothing to do if this is already a reset packet. - if s.flagIsSet(header.TCPFlagRst) { - return true + if !s.flagIsSet(header.TCPFlagRst) { + replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) } - replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) - return true + return stack.UnknownDestinationPacketHandled } // replyWithReset replies to the given segment with a reset segment. @@ -296,49 +249,49 @@ func replyWithReset(s *segment, tos, ttl uint8) { } // SetOption implements stack.TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case SACKEnabled: + case *tcpip.TCPSACKEnabled: p.mu.Lock() - p.sackEnabled = bool(v) + p.sackEnabled = bool(*v) p.mu.Unlock() return nil - case Recovery: + case *tcpip.TCPRecovery: p.mu.Lock() - p.recovery = Recovery(v) + p.recovery = *v p.mu.Unlock() return nil - case DelayEnabled: + case *tcpip.TCPDelayEnabled: p.mu.Lock() - p.delayEnabled = bool(v) + p.delayEnabled = bool(*v) p.mu.Unlock() return nil - case SendBufferSizeOption: + case *tcpip.TCPSendBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.sendBufferSize = v + p.sendBufferSize = *v p.mu.Unlock() return nil - case ReceiveBufferSizeOption: + case *tcpip.TCPReceiveBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.recvBufferSize = v + p.recvBufferSize = *v p.mu.Unlock() return nil - case tcpip.CongestionControlOption: + case *tcpip.CongestionControlOption: for _, c := range p.availableCongestionControl { - if string(v) == c { + if string(*v) == c { p.mu.Lock() - p.congestionControl = string(v) + p.congestionControl = string(*v) p.mu.Unlock() return nil } @@ -347,75 +300,79 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { // is specified. return tcpip.ErrNoSuchFile - case tcpip.ModerateReceiveBufferOption: + case *tcpip.TCPModerateReceiveBufferOption: p.mu.Lock() - p.moderateReceiveBuffer = bool(v) + p.moderateReceiveBuffer = bool(*v) p.mu.Unlock() return nil - case tcpip.TCPLingerTimeoutOption: - if v < 0 { - v = 0 - } + case *tcpip.TCPLingerTimeoutOption: p.mu.Lock() - p.lingerTimeout = time.Duration(v) + if *v < 0 { + p.lingerTimeout = 0 + } else { + p.lingerTimeout = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPTimeWaitTimeoutOption: - if v < 0 { - v = 0 - } + case *tcpip.TCPTimeWaitTimeoutOption: p.mu.Lock() - p.timeWaitTimeout = time.Duration(v) + if *v < 0 { + p.timeWaitTimeout = 0 + } else { + p.timeWaitTimeout = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPTimeWaitReuseOption: - if v < tcpip.TCPTimeWaitReuseDisabled || v > tcpip.TCPTimeWaitReuseLoopbackOnly { + case *tcpip.TCPTimeWaitReuseOption: + if *v < tcpip.TCPTimeWaitReuseDisabled || *v > tcpip.TCPTimeWaitReuseLoopbackOnly { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.timeWaitReuse = v + p.timeWaitReuse = *v p.mu.Unlock() return nil - case tcpip.TCPMinRTOOption: - if v < 0 { - v = tcpip.TCPMinRTOOption(MinRTO) - } + case *tcpip.TCPMinRTOOption: p.mu.Lock() - p.minRTO = time.Duration(v) + if *v < 0 { + p.minRTO = MinRTO + } else { + p.minRTO = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPMaxRTOOption: - if v < 0 { - v = tcpip.TCPMaxRTOOption(MaxRTO) - } + case *tcpip.TCPMaxRTOOption: p.mu.Lock() - p.maxRTO = time.Duration(v) + if *v < 0 { + p.maxRTO = MaxRTO + } else { + p.maxRTO = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPMaxRetriesOption: + case *tcpip.TCPMaxRetriesOption: p.mu.Lock() - p.maxRetries = uint32(v) + p.maxRetries = uint32(*v) p.mu.Unlock() return nil - case tcpip.TCPSynRcvdCountThresholdOption: + case *tcpip.TCPSynRcvdCountThresholdOption: p.mu.Lock() - p.synRcvdCount.SetThreshold(uint64(v)) + p.synRcvdCount.SetThreshold(uint64(*v)) p.mu.Unlock() return nil - case tcpip.TCPSynRetriesOption: - if v < 1 || v > 255 { + case *tcpip.TCPSynRetriesOption: + if *v < 1 || *v > 255 { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.synRetries = uint8(v) + p.synRetries = uint8(*v) p.mu.Unlock() return nil @@ -425,33 +382,33 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } // Option implements stack.TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case *SACKEnabled: + case *tcpip.TCPSACKEnabled: p.mu.RLock() - *v = SACKEnabled(p.sackEnabled) + *v = tcpip.TCPSACKEnabled(p.sackEnabled) p.mu.RUnlock() return nil - case *Recovery: + case *tcpip.TCPRecovery: p.mu.RLock() - *v = Recovery(p.recovery) + *v = tcpip.TCPRecovery(p.recovery) p.mu.RUnlock() return nil - case *DelayEnabled: + case *tcpip.TCPDelayEnabled: p.mu.RLock() - *v = DelayEnabled(p.delayEnabled) + *v = tcpip.TCPDelayEnabled(p.delayEnabled) p.mu.RUnlock() return nil - case *SendBufferSizeOption: + case *tcpip.TCPSendBufferSizeRangeOption: p.mu.RLock() *v = p.sendBufferSize p.mu.RUnlock() return nil - case *ReceiveBufferSizeOption: + case *tcpip.TCPReceiveBufferSizeRangeOption: p.mu.RLock() *v = p.recvBufferSize p.mu.RUnlock() @@ -463,15 +420,15 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.RUnlock() return nil - case *tcpip.AvailableCongestionControlOption: + case *tcpip.TCPAvailableCongestionControlOption: p.mu.RLock() - *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) + *v = tcpip.TCPAvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) p.mu.RUnlock() return nil - case *tcpip.ModerateReceiveBufferOption: + case *tcpip.TCPModerateReceiveBufferOption: p.mu.RLock() - *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer) + *v = tcpip.TCPModerateReceiveBufferOption(p.moderateReceiveBuffer) p.mu.RUnlock() return nil @@ -546,33 +503,19 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter { // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - // TCP header is variable length, peek at it first. - hdrLen := header.TCPMinimumSize - hdr, ok := pkt.Data.PullUp(hdrLen) - if !ok { - return false - } - - // If the header has options, pull those up as well. - if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { - // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of - // packets. - hdrLen = offset - } - - _, ok = pkt.TransportHeader().Consume(hdrLen) - return ok + return parse.TCP(pkt) } // NewProtocol returns a TCP transport protocol. -func NewProtocol() stack.TransportProtocol { +func NewProtocol(s *stack.Stack) stack.TransportProtocol { p := protocol{ - sendBufferSize: SendBufferSizeOption{ + stack: s, + sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{ Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize, }, - recvBufferSize: ReceiveBufferSizeOption{ + recvBufferSize: tcpip.TCPReceiveBufferSizeRangeOption{ Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize, @@ -587,7 +530,7 @@ func NewProtocol() stack.TransportProtocol { minRTO: MinRTO, maxRTO: MaxRTO, maxRetries: MaxRetries, - recovery: RACKLossDetection, + recovery: tcpip.TCPRACKLossDetection, } p.dispatcher.init(runtime.GOMAXPROCS(0)) return &p diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 5e0bfe585..48bf196d8 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -47,22 +47,24 @@ type receiver struct { closed bool + // pendingRcvdSegments is bounded by the receive buffer size of the + // endpoint. pendingRcvdSegments segmentHeap - pendingBufUsed seqnum.Size - pendingBufSize seqnum.Size + // pendingBufUsed tracks the total number of bytes (including segment + // overhead) currently queued in pendingRcvdSegments. + pendingBufUsed int // Time when the last ack was received. lastRcvdAckTime time.Time `state:".(unixTime)"` } -func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { +func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { return &receiver{ ep: ep, rcvNxt: irs + 1, rcvAcc: irs.Add(rcvWnd + 1), rcvWnd: rcvWnd, rcvWndScale: rcvWndScale, - pendingBufSize: pendingBufSize, lastRcvdAckTime: time.Now(), } } @@ -85,15 +87,30 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // getSendParams returns the parameters needed by the sender when building // segments to send. func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { - // Calculate the window size based on the available buffer space. - receiveBufferAvailable := r.ep.receiveBufferAvailable() - acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable)) - if r.rcvAcc.LessThan(acc) { - r.rcvAcc = acc + avail := wndFromSpace(r.ep.receiveBufferAvailable()) + if avail == 0 { + // We have no space available to accept any data, move to zero window + // state. + r.rcvWnd = 0 + return r.rcvNxt, 0 + } + + acc := r.rcvNxt.Add(seqnum.Size(avail)) + newWnd := r.rcvNxt.Size(acc) + curWnd := r.rcvNxt.Size(r.rcvAcc) + + // Update rcvAcc only if new window is > previously advertised window. We + // should never shrink the acceptable sequence space once it has been + // advertised the peer. If we shrink the acceptable sequence space then we + // would end up dropping bytes that might already be in flight. + if newWnd > curWnd { + r.rcvAcc = r.rcvNxt.Add(newWnd) + } else { + newWnd = curWnd } // Stash away the non-scaled receive window as we use it for measuring // receiver's estimated RTT. - r.rcvWnd = r.rcvNxt.Size(r.rcvAcc) + r.rcvWnd = newWnd return r.rcvNxt, r.rcvWnd >> r.rcvWndScale } @@ -195,7 +212,9 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum } for i := first; i < len(r.pendingRcvdSegments); i++ { + r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() r.pendingRcvdSegments[i].decRef() + // Note that slice truncation does not allow garbage collection of // truncated items, thus truncated items must be set to nil to avoid // memory leaks. @@ -268,14 +287,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // If we are in one of the shutdown states then we need to do // additional checks before we try and process the segment. switch state { - case StateCloseWait: - // If the ACK acks something not yet sent then we send an ACK. - if r.ep.snd.sndNxt.LessThan(s.ackNumber) { - r.ep.snd.sendAck() - return true, nil - } - fallthrough - case StateClosing, StateLastAck: + case StateCloseWait, StateClosing, StateLastAck: if !s.sequenceNumber.LessThanEq(r.rcvNxt) { // Just drop the segment as we have // already received a FIN and this @@ -284,9 +296,31 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo return true, nil } fallthrough - case StateFinWait1: - fallthrough - case StateFinWait2: + case StateFinWait1, StateFinWait2: + // If the ACK acks something not yet sent then we send an ACK. + // + // RFC793, page 37: If the connection is in a synchronized state, + // (ESTABLISHED, FIN-WAIT-1, FIN-WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK, + // TIME-WAIT), any unacceptable segment (out of window sequence number + // or unacceptable acknowledgment number) must elicit only an empty + // acknowledgment segment containing the current send-sequence number + // and an acknowledgment indicating the next sequence number expected + // to be received, and the connection remains in the same state. + // + // Just as on Linux, we do not apply this behavior when state is + // ESTABLISHED. + // Linux receive processing for all states except ESTABLISHED and + // TIME_WAIT is here where if the ACK check fails, we attempt to + // reply back with an ACK with correct seq/ack numbers. + // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L6186 + // The ESTABLISHED state processing is here where if the ACK check + // fails, we ignore the packet: + // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591 + if r.ep.snd.sndNxt.LessThan(s.ackNumber) { + r.ep.snd.sendAck() + return true, nil + } + // If we are closed for reads (either due to an // incoming FIN or the user calling shutdown(.., // SHUT_RD) then any data past the rcvNxt should @@ -369,10 +403,16 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { - // We only store the segment if it's within our buffer - // size limit. - if r.pendingBufUsed < r.pendingBufSize { - r.pendingBufUsed += seqnum.Size(s.segMemSize()) + // We only store the segment if it's within our buffer size limit. + // + // Only use 75% of the receive buffer queue for out-of-order + // segments. This ensures that we always leave some space for the inorder + // segments to arrive allowing pending segments to be processed and + // delivered to the user. + if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 { + r.ep.rcvListMu.Lock() + r.pendingBufUsed += s.segMemSize() + r.ep.rcvListMu.Unlock() s.incRef() heap.Push(&r.pendingRcvdSegments, s) UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) @@ -406,7 +446,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { } heap.Pop(&r.pendingRcvdSegments) - r.pendingBufUsed -= seqnum.Size(s.segMemSize()) + r.ep.rcvListMu.Lock() + r.pendingBufUsed -= s.segMemSize() + r.ep.rcvListMu.Unlock() s.decRef() } return false, nil @@ -421,6 +463,13 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // Just silently drop any RST packets in TIME_WAIT. We do not support // TIME_WAIT assasination as a result we confirm w/ fix 1 as described // in https://tools.ietf.org/html/rfc1337#section-3. + // + // This behavior overrides RFC793 page 70 where we transition to CLOSED + // on receiving RST, which is also default Linux behavior. + // On Linux the RST can be ignored by setting sysctl net.ipv4.tcp_rfc1337. + // + // As we do not yet support PAWS, we are being conservative in ignoring + // RSTs by default. if s.flagIsSet(header.TCPFlagRst) { return false, false } diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 94307d31a..13acaf753 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -15,6 +15,7 @@ package tcp import ( + "fmt" "sync/atomic" "time" @@ -24,6 +25,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// queueFlags are used to indicate which queue of an endpoint a particular segment +// belongs to. This is used to track memory accounting correctly. +type queueFlags uint8 + +const ( + recvQ queueFlags = 1 << iota + sendQ +) + // segment represents a TCP segment. It holds the payload and parsed TCP segment // information, and can be added to intrusive lists. // segment is mostly immutable, the only field allowed to change is viewToDeliver. @@ -32,6 +42,8 @@ import ( type segment struct { segmentEntry refCnt int32 + ep *endpoint + qFlags queueFlags id stack.TransportEndpointID `state:"manual"` route stack.Route `state:"manual"` data buffer.VectorisedView `state:".(buffer.VectorisedView)"` @@ -100,6 +112,8 @@ func (s *segment) clone() *segment { rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, xmitCount: s.xmitCount, + ep: s.ep, + qFlags: s.qFlags, } t.data = s.data.Clone(t.views[:]) return t @@ -115,8 +129,34 @@ func (s *segment) flagsAreSet(flags uint8) bool { return s.flags&flags == flags } +// setOwner sets the owning endpoint for this segment. Its required +// to be called to ensure memory accounting for receive/send buffer +// queues is done properly. +func (s *segment) setOwner(ep *endpoint, qFlags queueFlags) { + switch qFlags { + case recvQ: + ep.updateReceiveMemUsed(s.segMemSize()) + case sendQ: + // no memory account for sendQ yet. + default: + panic(fmt.Sprintf("unexpected queue flag %b", qFlags)) + } + s.ep = ep + s.qFlags = qFlags +} + func (s *segment) decRef() { if atomic.AddInt32(&s.refCnt, -1) == 0 { + if s.ep != nil { + switch s.qFlags { + case recvQ: + s.ep.updateReceiveMemUsed(-s.segMemSize()) + case sendQ: + // no memory accounting for sendQ yet. + default: + panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags)) + } + } s.route.Release() } } @@ -138,6 +178,11 @@ func (s *segment) logicalLen() seqnum.Size { return l } +// payloadSize is the size of s.data. +func (s *segment) payloadSize() int { + return s.data.Size() +} + // segMemSize is the amount of memory used to hold the segment data and // the associated metadata. func (s *segment) segMemSize() int { diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index 48a257137..54545a1b1 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -22,16 +22,16 @@ import ( // // +stateify savable type segmentQueue struct { - mu sync.Mutex `state:"nosave"` - list segmentList `state:"wait"` - limit int - used int + mu sync.Mutex `state:"nosave"` + list segmentList `state:"wait"` + ep *endpoint + frozen bool } // emptyLocked determines if the queue is empty. // Preconditions: q.mu must be held. func (q *segmentQueue) emptyLocked() bool { - return q.used == 0 + return q.list.Empty() } // empty determines if the queue is empty. @@ -43,14 +43,6 @@ func (q *segmentQueue) empty() bool { return r } -// setLimit updates the limit. No segments are immediately dropped in case the -// queue becomes full due to the new limit. -func (q *segmentQueue) setLimit(limit int) { - q.mu.Lock() - q.limit = limit - q.mu.Unlock() -} - // enqueue adds the given segment to the queue. // // Returns true when the segment is successfully added to the queue, in which @@ -58,15 +50,23 @@ func (q *segmentQueue) setLimit(limit int) { // false if the queue is full, in which case ownership is retained by the // caller. func (q *segmentQueue) enqueue(s *segment) bool { + // q.ep.receiveBufferParams() must be called without holding q.mu to + // avoid lock order inversion. + bufSz := q.ep.receiveBufferSize() + used := q.ep.receiveMemUsed() q.mu.Lock() - r := q.used < q.limit - if r { + // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue + // is currently full). + allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen + + if allow { q.list.PushBack(s) - q.used++ + // Set the owner now that the endpoint owns the segment. + s.setOwner(q.ep, recvQ) } q.mu.Unlock() - return r + return allow } // dequeue removes and returns the next segment from queue, if one exists. @@ -77,9 +77,25 @@ func (q *segmentQueue) dequeue() *segment { s := q.list.Front() if s != nil { q.list.Remove(s) - q.used-- } q.mu.Unlock() return s } + +// freeze prevents any more segments from being added to the queue. i.e all +// future segmentQueue.enqueue will return false and not add the segment to the +// queue till the queue is unfroze with a corresponding segmentQueue.thaw call. +func (q *segmentQueue) freeze() { + q.mu.Lock() + q.frozen = true + q.mu.Unlock() +} + +// thaw unfreezes a previously frozen queue using segmentQueue.freeze() and +// allows new segments to be queued again. +func (q *segmentQueue) thaw() { + q.mu.Lock() + q.frozen = false + q.mu.Unlock() +} diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 99521f0c1..ef7f5719f 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -46,8 +46,9 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { t.Helper() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err) + opt := tcpip.TCPSACKEnabled(enable) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("c.s.SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -162,8 +163,9 @@ func TestSackPermittedAccept(t *testing.T) { // Set the SynRcvd threshold to // zero to force a syn cookie // based accept to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } setStackSACKPermitted(t, c, sackEnabled) @@ -236,8 +238,9 @@ func TestSackDisabledAccept(t *testing.T) { // Set the SynRcvd threshold to // zero to force a syn cookie // based accept to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index adb32e428..5b504d0d1 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -240,6 +241,38 @@ func TestTCPResetsSentIncrement(t *testing.T) { } } +// TestTCPResetsSentNoICMP confirms that we don't get an ICMP +// DstUnreachable packet when we try send a packet which is not part +// of an active session. +func TestTCPResetsSentNoICMP(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + stats := c.Stack().Stats() + + // Send a SYN request for a closed port. This should elicit an RST + // but NOT an ICMPv4 DstUnreachable packet. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive whatever comes back. + b := c.GetPacket() + ipHdr := header.IPv4(b) + if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want { + t.Errorf("unexpected protocol, got = %d, want = %d", got, want) + } + + // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded. + sent := stats.ICMP.V4PacketsSent + if got, want := sent.DstUnreachable.Value(), uint64(0); got != want { + t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want) + } +} + // TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates // a RST if an ACK is received on the listening socket for which there is no // active handshake in progress and we are not using SYN cookies. @@ -291,12 +324,12 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -309,16 +342,16 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // Lower stackwide TIME_WAIT timeout so that the reservations // are released instantly on Close. tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { - t.Fatalf("e.stack.SetTransportProtocolOption(%d, %#v) = %s", tcp.ProtocolNumber, tcpTW, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err) } c.EP.Close() checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ SrcPort: context.TestPort, @@ -348,8 +381,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(0), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst))) } @@ -432,8 +465,9 @@ func TestConnectResetAfterClose(t *testing.T) { // Set TCPLinger to 3 seconds so that sockets are marked closed // after 3 second in FIN_WAIT2 state. tcpLingerTimeout := 3 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%s) failed: %s", tcpLingerTimeout, err) + opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -446,8 +480,8 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -488,8 +522,8 @@ func TestConnectResetAfterClose(t *testing.T) { // RST is always generated with sndNxt which if the FIN // has been sent will be 1 higher than the sequence number // of the FIN itself. - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(0), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -506,8 +540,9 @@ func TestCurrentConnectedIncrement(t *testing.T) { // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed // after 1 second in TIME_WAIT state. tcpTimeWaitTimeout := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -527,8 +562,8 @@ func TestCurrentConnectedIncrement(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -563,8 +598,8 @@ func TestCurrentConnectedIncrement(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -610,8 +645,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -631,8 +666,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -691,8 +726,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(0), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -743,8 +778,8 @@ func TestSimpleReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -933,8 +968,8 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) { // Set the SynRcvd threshold to force a syn cookie based accept to happen. opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, %#v): %s", tcp.ProtocolNumber, opt, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { @@ -996,7 +1031,7 @@ func TestSendRstOnListenerRxSynAckV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } func TestSendRstOnListenerRxSynAckV6(t *testing.T) { @@ -1024,7 +1059,7 @@ func TestSendRstOnListenerRxSynAckV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } // TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, @@ -1061,8 +1096,8 @@ func TestTCPAckBeforeAcceptV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } // TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, @@ -1099,8 +1134,8 @@ func TestTCPAckBeforeAcceptV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } func TestSendRstOnListenerRxAckV4(t *testing.T) { @@ -1128,7 +1163,7 @@ func TestSendRstOnListenerRxAckV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } func TestSendRstOnListenerRxAckV6(t *testing.T) { @@ -1156,7 +1191,7 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } // TestListenShutdown tests for the listening endpoint replying with RST @@ -1272,8 +1307,8 @@ func TestTOSV4(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), // Acknum is initial sequence number + 1 + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), // Acknum is initial sequence number + 1 checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), checker.TOS(tos, 0), @@ -1321,8 +1356,8 @@ func TestTrafficClassV6(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), checker.TOS(tos, 0), @@ -1512,8 +1547,8 @@ func TestOutOfOrderReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1563,8 +1598,8 @@ func TestOutOfOrderReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1574,8 +1609,8 @@ func TestOutOfOrderFlood(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - // Create a new connection with initial window size of 10. - c.CreateConnected(789, 30000, 10) + rcvBufSz := math.MaxUint16 + c.CreateConnected(789, 30000, rcvBufSz) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) @@ -1596,8 +1631,8 @@ func TestOutOfOrderFlood(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1617,8 +1652,8 @@ func TestOutOfOrderFlood(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1637,8 +1672,8 @@ func TestOutOfOrderFlood(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(793), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1679,8 +1714,8 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1694,7 +1729,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -1748,8 +1783,8 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1762,7 +1797,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - checker.SeqNum(uint32(c.IRS)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), )) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { @@ -1781,7 +1816,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { // RST is always generated with sndNxt which if the FIN // has been sent will be 1 higher than the sequence // number of the FIN itself. - checker.SeqNum(uint32(c.IRS)+2), + checker.TCPSeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -1827,7 +1862,8 @@ func TestFullWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, 10) + const rcvBufSz = 10 + c.CreateConnected(789, 30000, rcvBufSz) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1838,8 +1874,13 @@ func TestFullWindowReceive(t *testing.T) { t.Fatalf("Read failed: %s", err) } - // Fill up the window. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies + // the provided buffer value by tcp.SegOverheadFactor to calculate the actual + // receive buffer size. + data := make([]byte, tcp.SegOverheadFactor*rcvBufSz) + for i := range data { + data[i] = byte(i % 255) + } c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, @@ -1860,10 +1901,10 @@ func TestFullWindowReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), + checker.TCPWindow(0), ), ) @@ -1886,10 +1927,10 @@ func TestFullWindowReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), - checker.Window(10), + checker.TCPWindow(10), ), ) } @@ -1898,12 +1939,15 @@ func TestNoWindowShrinking(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - // Start off with a window size of 10, then shrink it to 5. - c.CreateConnected(789, 30000, 10) - - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) - } + // Start off with a certain receive buffer then cut it in half and verify that + // the right edge of the window does not shrink. + // NOTE: Netstack doubles the value specified here. + rcvBufSize := 65536 + iss := seqnum.Value(789) + // Enable window scaling with a scale of zero from our end. + c.CreateConnectedWithRawOptions(iss, 30000, rcvBufSize, []byte{ + header.TCPOptionWS, 3, 0, header.TCPOptionNOP, + }) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1912,14 +1956,15 @@ func TestNoWindowShrinking(t *testing.T) { if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } - - // Send 3 bytes, check that the peer acknowledges them. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - c.SendPacket(data[:3], &context.Headers{ + // Send a 1 byte payload so that we can record the current receive window. + // Send a payload of half the size of rcvBufSize. + seqNum := iss.Add(1) + payload := []byte{1} + c.SendPacket(payload, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: seqNum, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1931,46 +1976,93 @@ func TestNoWindowShrinking(t *testing.T) { t.Fatalf("Timed out waiting for data to arrive") } - // Check that data is acknowledged, and that window doesn't go to zero - // just yet because it was previously set to 10. It must go to 7 now. - checker.IPv4(t, c.GetPacket(), + // Read the 1 byte payload we just sent. + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + if got, want := payload, v; !bytes.Equal(got, want) { + t.Fatalf("got data: %v, want: %v", got, want) + } + + seqNum = seqNum.Add(1) + // Verify that the ACK does not shrink the window. + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), checker.TCPFlags(header.TCPFlagAck), - checker.Window(7), ), ) + // Stash the initial window. + initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale + initialLastAcceptableSeq := seqNum.Add(seqnum.Size(initialWnd)) + // Now shrink the receive buffer to half its original size. + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil { + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) + } - // Send 7 more bytes, check that the window fills up. - c.SendPacket(data[3:], &context.Headers{ + data := generateRandomPayload(t, rcvBufSize) + // Send a payload of half the size of rcvBufSize. + c.SendPacket(data[:rcvBufSize/2], &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 793, + SeqNum: seqNum, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) + seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") + // Verify that the ACK does not shrink the window. + pkt = c.GetPacket() + checker.IPv4(t, pkt, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale + newLastAcceptableSeq := seqNum.Add(seqnum.Size(newWnd)) + if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) { + t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq) } + // Send another payload of half the size of rcvBufSize. This should fill up the + // socket receive buffer and we should see a zero window. + c.SendPacket(data[rcvBufSize/2:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqNum, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) + checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), + checker.TCPWindow(0), ), ) + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + // Receive data and check it. - read := make([]byte, 0, 10) + read := make([]byte, 0, rcvBufSize) for len(read) < len(data) { v, _, err := c.EP.Read(nil) if err != nil { @@ -1984,15 +2076,15 @@ func TestNoWindowShrinking(t *testing.T) { t.Fatalf("got data = %v, want = %v", read, data) } - // Check that we get an ACK for the newly non-zero window, which is the - // new size. + // Check that we get an ACK for the newly non-zero window, which is the new + // receive buffer size we set after the connection was established. checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), checker.TCPFlags(header.TCPFlagAck), - checker.Window(5), + checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale), ), ) } @@ -2017,8 +2109,8 @@ func TestSimpleSend(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2059,8 +2151,8 @@ func TestZeroWindowSend(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2081,8 +2173,8 @@ func TestZeroWindowSend(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2121,16 +2213,16 @@ func TestScaledWindowConnect(t *testing.T) { t.Fatalf("Write failed: %s", err) } - // Check that data is received, and that advertised window is 0xbfff, + // Check that data is received, and that advertised window is 0x5fff, // that is, that it is scaled. b := c.GetPacket() checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0x5fff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2160,9 +2252,9 @@ func TestNonScaledWindowConnect(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0xffff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2196,19 +2288,20 @@ func TestScaledWindowAccept(t *testing.T) { } // Do 3-way handshake. - c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS}) + // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2 + c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}) // Try to accept the connection. we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2226,16 +2319,16 @@ func TestScaledWindowAccept(t *testing.T) { t.Fatalf("Write failed: %s", err) } - // Check that data is received, and that advertised window is 0xbfff, + // Check that data is received, and that advertised window is 0x5fff, // that is, that it is scaled. b := c.GetPacket() checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0x5fff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2277,12 +2370,12 @@ func TestNonScaledWindowAccept(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2307,9 +2400,9 @@ func TestNonScaledWindowAccept(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0xffff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2322,18 +2415,19 @@ func TestZeroScaledWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - // Set the window size such that a window scale of 4 will be used. - const wnd = 65535 * 10 - const ws = uint32(4) - c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{ + // Set the buffer size such that a window scale of 5 will be used. + const bufSz = 65535 * 10 + const ws = uint32(5) + c.CreateConnectedWithRawOptions(789, 30000, bufSz, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) // Write chunks of 50000 bytes. - remain := wnd + remain := 0 sent := 0 data := make([]byte, 50000) - for remain > len(data) { + // Keep writing till the window drops below len(data). + for { c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, @@ -2343,21 +2437,25 @@ func TestZeroScaledWindowReceive(t *testing.T) { RcvWnd: 30000, }) sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(remain>>ws)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), checker.TCPFlags(header.TCPFlagAck), ), ) + // Don't reduce window to zero here. + if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<<ws < len(data) { + remain = wnd << ws + break + } } // Make the window non-zero, but the scaled window zero. - if remain >= 16 { + for remain >= 16 { data = data[:remain-15] c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, @@ -2368,22 +2466,35 @@ func TestZeroScaledWindowReceive(t *testing.T) { RcvWnd: 30000, }) sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(0), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), checker.TCPFlags(header.TCPFlagAck), ), ) + // Since the receive buffer is split between window advertisement and + // application data buffer the window does not always reflect the space + // available and actual space available can be a bit more than what is + // advertised in the window. + wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) + if wnd == 0 { + break + } + remain = wnd << ws } - // Read at least 1MSS of data. An ack should be sent in response to that. + // Read at least 2MSS of data. An ack should be sent in response to that. + // Since buffer space is now split in half between window and application + // data we need to read more than 1 MSS(65536) of data for a non-zero window + // update to be sent. For 1MSS worth of window to be available we need to + // read at least 128KB. Since our segments above were 50KB each it means + // we need to read at 3 packets. sz := 0 - for sz < defaultMTU { + for sz < defaultMTU*2 { v, _, err := c.EP.Read(nil) if err != nil { t.Fatalf("Read failed: %s", err) @@ -2395,9 +2506,9 @@ func TestZeroScaledWindowReceive(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(sz>>ws)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -2464,8 +2575,8 @@ func TestSegmentMerging(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize+1), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+uint32(i)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2487,8 +2598,8 @@ func TestSegmentMerging(t *testing.T) { checker.PayloadLen(len(allData)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+11), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+11), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2535,8 +2646,8 @@ func TestDelay(t *testing.T) { checker.PayloadLen(len(want)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2582,8 +2693,8 @@ func TestUndelay(t *testing.T) { checker.PayloadLen(len(allData[0])+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2605,8 +2716,8 @@ func TestUndelay(t *testing.T) { checker.PayloadLen(len(allData[1])+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2667,8 +2778,8 @@ func TestMSSNotDelayed(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2719,8 +2830,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2840,12 +2951,12 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2867,8 +2978,9 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + opt := tcpip.TCPSynRcvdCountThresholdOption(0) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } // Create EP and start listening. @@ -2895,12 +3007,12 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2961,7 +3073,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. const rcvBufferSize = 0x20000 - const wndScale = 2 + const wndScale = 3 if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) } @@ -2996,7 +3108,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagSyn), checker.SrcPort(tcpHdr.SourcePort()), - checker.SeqNum(tcpHdr.SequenceNumber()), + checker.TCPSeqNum(tcpHdr.SequenceNumber()), checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), ), ) @@ -3017,8 +3129,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), ), ) @@ -3146,8 +3258,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) { defer c.Cleanup() const numRetries = 2 - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRetriesOption(numRetries)); err != nil { - t.Fatalf("could not set protocol option MaxRetries.\n") + opt := tcpip.TCPMaxRetriesOption(numRetries) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) @@ -3206,8 +3319,9 @@ func TestMaxRTO(t *testing.T) { defer c.Cleanup() rto := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRTOOption(rto)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPMaxRTO(%d) failed: %s", rto, err) + opt := tcpip.TCPMaxRTOOption(rto) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) @@ -3309,8 +3423,8 @@ func TestFinImmediately(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3330,8 +3444,8 @@ func TestFinImmediately(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3352,8 +3466,8 @@ func TestFinRetransmit(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3363,8 +3477,8 @@ func TestFinRetransmit(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3384,8 +3498,8 @@ func TestFinRetransmit(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3408,8 +3522,8 @@ func TestFinWithNoPendingData(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3433,8 +3547,8 @@ func TestFinWithNoPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3455,8 +3569,8 @@ func TestFinWithNoPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3483,8 +3597,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3502,8 +3616,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3522,8 +3636,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3543,8 +3657,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3567,8 +3681,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3592,8 +3706,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3608,8 +3722,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3629,8 +3743,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3654,8 +3768,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3675,8 +3789,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3690,8 +3804,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3706,8 +3820,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3798,8 +3912,8 @@ func scaledSendWindow(t *testing.T, scale uint8) { checker.PayloadLen((1<<scale)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3937,7 +4051,7 @@ func TestReceivedSegmentQueuing(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3964,8 +4078,9 @@ func TestReadAfterClosedState(t *testing.T) { // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed // after 1 second in TIME_WAIT state. tcpTimeWaitTimeout := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -3987,8 +4102,8 @@ func TestReadAfterClosedState(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -4012,8 +4127,8 @@ func TestReadAfterClosedState(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(uint32(791+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(uint32(791+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -4185,8 +4300,8 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func TestDefaultBufferSizes(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) // Check the default values. @@ -4204,11 +4319,15 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default send buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{ - Min: 1, - Default: tcp.DefaultSendBufferSize * 2, - Max: tcp.DefaultSendBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPSendBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultSendBufferSize * 2, + Max: tcp.DefaultSendBufferSize * 20, + } + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } ep.Close() @@ -4221,11 +4340,15 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default receive buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{ - Min: 1, - Default: tcp.DefaultReceiveBufferSize * 3, - Max: tcp.DefaultReceiveBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize * 3, + Max: tcp.DefaultReceiveBufferSize * 30, + } + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } ep.Close() @@ -4240,8 +4363,8 @@ func TestDefaultBufferSizes(t *testing.T) { func TestMinMaxBufferSizes(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) // Check the default values. @@ -4252,22 +4375,28 @@ func TestMinMaxBufferSizes(t *testing.T) { defer ep.Close() // Change the min/max values for send/receive - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPSendBufferSizeRangeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } - // Set values below the min. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { + // Set values below the min/2. + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 99); err != nil { t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err) } checkRecvBufferSize(t, ep, 200) - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 149); err != nil { t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err) } @@ -4278,19 +4407,21 @@ func TestMinMaxBufferSizes(t *testing.T) { t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) + // Values above max are capped at max and then doubled. + checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err) } - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) + // Values above max are capped at max and then doubled. + checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) } func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}}) ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -4339,11 +4470,11 @@ func TestBindToDeviceOption(t *testing.T) { func makeStack() (*stack.Stack, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - ipv4.NewProtocol(), - ipv6.NewProtocol(), + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + ipv6.NewProtocol, }, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) id := loopback.New() @@ -4626,8 +4757,8 @@ func TestPathMTUDiscovery(t *testing.T) { checker.PayloadLen(size+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(seqNum), - checker.AckNum(790), + checker.TCPSeqNum(seqNum), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -4718,8 +4849,8 @@ func TestStackSetCongestionControl(t *testing.T) { t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err { - t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err { + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err) } var cc tcpip.CongestionControlOption @@ -4751,12 +4882,12 @@ func TestStackAvailableCongestionControl(t *testing.T) { s := c.Stack() // Query permitted congestion control algorithms. - var aCC tcpip.AvailableCongestionControlOption + var aCC tcpip.TCPAvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) } - if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want) } } @@ -4767,18 +4898,18 @@ func TestStackSetAvailableCongestionControl(t *testing.T) { s := c.Stack() // Setting AvailableCongestionControlOption should fail. - aCC := tcpip.AvailableCongestionControlOption("xyz") + aCC := tcpip.TCPAvailableCongestionControlOption("xyz") if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC) + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC) } // Verify that we still get the expected list of congestion control options. - var cc tcpip.AvailableCongestionControlOption + var cc tcpip.TCPAvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) + t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err) } - if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want) } } @@ -4842,8 +4973,8 @@ func TestEndpointSetCongestionControl(t *testing.T) { func enableCUBIC(t *testing.T, c *context.Context) { t.Helper() opt := tcpip.CongestionControlOption("cubic") - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %s = %s", opt, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)) %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -4878,8 +5009,8 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(uint32(790)), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(uint32(790)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -4912,8 +5043,8 @@ func TestKeepalive(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -4924,8 +5055,8 @@ func TestKeepalive(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), ), ) @@ -4950,8 +5081,8 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next-1)), - checker.AckNum(uint32(790)), + checker.TCPSeqNum(uint32(next-1)), + checker.TCPAckNum(uint32(790)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -4977,8 +5108,8 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next)), - checker.AckNum(uint32(0)), + checker.TCPSeqNum(uint32(next)), + checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -5018,7 +5149,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki checker.SrcPort(context.StackPort), checker.DstPort(srcPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } if synCookieInUse { @@ -5062,7 +5193,7 @@ func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCoo checker.SrcPort(context.StackPort), checker.DstPort(srcPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } if synCookieInUse { @@ -5135,12 +5266,12 @@ func TestListenBacklogFull(t *testing.T) { defer c.WQ.EventUnregister(&we) for i := 0; i < listenBacklog; i++ { - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5152,7 +5283,7 @@ func TestListenBacklogFull(t *testing.T) { } // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != tcpip.ErrWouldBlock { select { case <-ch: @@ -5164,12 +5295,12 @@ func TestListenBacklogFull(t *testing.T) { // Now a new handshake must succeed. executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5194,6 +5325,8 @@ func TestListenBacklogFull(t *testing.T) { func TestListenNoAcceptNonUnicastV4(t *testing.T) { multicastAddr := tcpip.Address("\xe0\x00\x01\x02") otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") + subnet := context.StackAddrWithPrefix.Subnet() + subnetBroadcastAddr := subnet.Broadcast() tests := []struct { name string @@ -5201,53 +5334,59 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { dstAddr tcpip.Address }{ { - "SourceUnspecified", - header.IPv4Any, - context.StackAddr, + name: "SourceUnspecified", + srcAddr: header.IPv4Any, + dstAddr: context.StackAddr, }, { - "SourceBroadcast", - header.IPv4Broadcast, - context.StackAddr, + name: "SourceBroadcast", + srcAddr: header.IPv4Broadcast, + dstAddr: context.StackAddr, }, { - "SourceOurMulticast", - multicastAddr, - context.StackAddr, + name: "SourceOurMulticast", + srcAddr: multicastAddr, + dstAddr: context.StackAddr, }, { - "SourceOtherMulticast", - otherMulticastAddr, - context.StackAddr, + name: "SourceOtherMulticast", + srcAddr: otherMulticastAddr, + dstAddr: context.StackAddr, }, { - "DestUnspecified", - context.TestAddr, - header.IPv4Any, + name: "DestUnspecified", + srcAddr: context.TestAddr, + dstAddr: header.IPv4Any, }, { - "DestBroadcast", - context.TestAddr, - header.IPv4Broadcast, + name: "DestBroadcast", + srcAddr: context.TestAddr, + dstAddr: header.IPv4Broadcast, }, { - "DestOurMulticast", - context.TestAddr, - multicastAddr, + name: "DestOurMulticast", + srcAddr: context.TestAddr, + dstAddr: multicastAddr, }, { - "DestOtherMulticast", - context.TestAddr, - otherMulticastAddr, + name: "DestOtherMulticast", + srcAddr: context.TestAddr, + dstAddr: otherMulticastAddr, + }, + { + name: "SrcSubnetBroadcast", + srcAddr: subnetBroadcastAddr, + dstAddr: context.StackAddr, + }, + { + name: "DestSubnetBroadcast", + srcAddr: context.TestAddr, + dstAddr: subnetBroadcastAddr, }, } for _, test := range tests { - test := test // capture range variable - t.Run(test.name, func(t *testing.T) { - t.Parallel() - c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5288,7 +5427,7 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) + checker.TCPAckNum(uint32(irs)+1))) }) } } @@ -5296,8 +5435,8 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { // TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a // non unicast IPv6 address are not accepted. func TestListenNoAcceptNonUnicastV6(t *testing.T) { - multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") - otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") + multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") + otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") tests := []struct { name string @@ -5347,11 +5486,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) { } for _, test := range tests { - test := test // capture range variable - t.Run(test.name, func(t *testing.T) { - t.Parallel() - c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5392,7 +5527,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) + checker.TCPAckNum(uint32(irs)+1))) }) } } @@ -5440,7 +5575,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) @@ -5476,12 +5611,12 @@ func TestListenSynRcvdQueueFull(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5505,8 +5640,9 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err) + opt := tcpip.TCPSynRcvdCountThresholdOption(1) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } // Create TCP endpoint. @@ -5552,12 +5688,12 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5568,7 +5704,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != tcpip.ErrWouldBlock { select { case <-ch: @@ -5617,7 +5753,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) @@ -5638,8 +5774,8 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.AckNum(uint32(irs) + 1), - checker.SeqNum(uint32(iss + 1)), + checker.TCPAckNum(uint32(irs) + 1), + checker.TCPSeqNum(uint32(iss + 1)), } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) @@ -5657,7 +5793,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { RcvWnd: 30000, }) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err != nil && err != tcpip.ErrWouldBlock { t.Fatalf("Accept failed: %s", err) @@ -5672,7 +5808,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5730,12 +5866,12 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { defer c.WQ.EventUnregister(&we) // Verify that there is only one acceptable connection at this point. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5800,12 +5936,12 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { defer c.WQ.EventUnregister(&we) // Now check that there is one acceptable connections. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5853,12 +5989,12 @@ func TestEndpointBindListenAcceptState(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - aep, _, err := ep.Accept() + aep, _, err := ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - aep, _, err = ep.Accept() + aep, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5906,13 +6042,19 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // the segment queue holding unprocessed packets is limited to 500. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } // Enable auto-tuning. - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPModerateReceiveBufferOption(true) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Change the expected window scale to match the value needed for the // maximum buffer size defined above. @@ -5931,16 +6073,14 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { time.Sleep(latency) rawEP.SendPacketWithTS([]byte{1}, tsVal) - // Verify that the ACK has the expected window. - wantRcvWnd := receiveBufferSize - wantRcvWnd = (wantRcvWnd >> uint32(c.WindowScale)) - rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1)) + pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) + rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() time.Sleep(25 * time.Millisecond) // Allocate a large enough payload for the test. - b := make([]byte, int(receiveBufferSize)*2) - offset := 0 - payloadSize := receiveBufferSize - 1 + payloadSize := receiveBufferSize * 2 + b := make([]byte, int(payloadSize)) + worker := (c.EP).(interface { StopWork() ResumeWork() @@ -5949,11 +6089,15 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Stop the worker goroutine. worker.StopWork() - start := offset - end := offset + payloadSize + start := 0 + end := payloadSize / 2 packetsSent := 0 for ; start < end; start += mss { - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) + packetEnd := start + mss + if start+mss > end { + packetEnd = end + } + rawEP.SendPacketWithTS(b[start:packetEnd], tsVal) packetsSent++ } @@ -5961,29 +6105,20 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // are waiting to be read. worker.ResumeWork() - // Since we read no bytes the window should goto zero till the - // application reads some of the data. - // Discard all intermediate acks except the last one. - if packetsSent > 100 { - for i := 0; i < (packetsSent / 100); i++ { - _ = c.GetPacket() - } + // Since we sent almost the full receive buffer worth of data (some may have + // been dropped due to segment overheads), we should get a zero window back. + pkt = c.GetPacket() + tcpHdr := header.TCP(header.IPv4(pkt).Payload()) + gotRcvWnd := tcpHdr.WindowSize() + wantAckNum := tcpHdr.AckNumber() + if got, want := int(gotRcvWnd), 0; got != want { + t.Fatalf("got rcvWnd: %d, want: %d", got, want) } - rawEP.VerifyACKRcvWnd(0) time.Sleep(25 * time.Millisecond) - // Verify that sending more data when window is closed is dropped and - // not acked. + // Verify that sending more data when receiveBuffer is exhausted. rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - // Verify that the stack sends us back an ACK with the sequence number - // of the last packet sent indicating it was dropped. - p := c.GetPacket() - checker.IPv4(t, p, checker.TCP( - checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)), - checker.Window(0), - )) - // Now read all the data from the endpoint and verify that advertised // window increases to the full available buffer size. for { @@ -5996,23 +6131,26 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Verify that we receive a non-zero window update ACK. When running // under thread santizer this test can end up sending more than 1 // ack, 1 for the non-zero window - p = c.GetPacket() + p := c.GetPacket() checker.IPv4(t, p, checker.TCP( - checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)), + checker.TCPAckNum(uint32(wantAckNum)), func(t *testing.T, h header.Transport) { tcp, ok := h.(header.TCP) if !ok { return } - if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) { - t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w) + // We use 10% here as the error margin upwards as the initial window we + // got was afer 1 segment was already in the receive buffer queue. + tolerance := 1.1 + if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) { + t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance)) } }, )) } -// This test verifies that the auto tuning does not grow the receive buffer if -// the application is not reading the data actively. +// This test verifies that the advertised window is auto-tuned up as the +// application is reading the data that is being received. func TestReceiveBufferAutoTuning(t *testing.T) { const mtu = 1500 const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize @@ -6022,26 +6160,33 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Enable Auto-tuning. stk := c.Stack() - // Set lower limits for auto-tuning tests. This is required because the - // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } // Enable auto-tuning. - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPModerateReceiveBufferOption(true) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Change the expected window scale to match the value needed for the // maximum buffer size used by stack. c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - - wantRcvWnd := receiveBufferSize + tsVal := uint32(rawEP.TSVal) + rawEP.NextSeqNum-- + rawEP.SendPacketWithTS(nil, tsVal) + rawEP.NextSeqNum++ + pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) + curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale scaleRcvWnd := func(rcvWnd int) uint16 { return uint16(rcvWnd >> uint16(c.WindowScale)) } @@ -6058,14 +6203,8 @@ func TestReceiveBufferAutoTuning(t *testing.T) { StopWork() ResumeWork() }) - tsVal := rawEP.TSVal - // We are going to do our own computation of what the moderated receive - // buffer should be based on sent/copied data per RTT and verify that - // the advertised window by the stack matches our calculations. - prevCopied := 0 - done := false latency := 1 * time.Millisecond - for i := 0; !done; i++ { + for i := 0; i < 5; i++ { tsVal++ // Stop the worker goroutine. @@ -6087,15 +6226,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Give 1ms for the worker to process the packets. time.Sleep(1 * time.Millisecond) - // Verify that the advertised window on the ACK is reduced by - // the total bytes sent. - expectedWnd := wantRcvWnd - totalSent - if packetsSent > 100 { - for i := 0; i < (packetsSent / 100); i++ { - _ = c.GetPacket() + lastACK := c.GetPacket() + // Discard any intermediate ACKs and only check the last ACK we get in a + // short time period of few ms. + for { + time.Sleep(1 * time.Millisecond) + pkt := c.GetPacketNonBlocking() + if pkt == nil { + break } + lastACK = pkt + } + if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want { + t.Fatalf("advertised window got: %d, want <= %d", got, want) } - rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd)) // Now read all the data from the endpoint and invoke the // moderation API to allow for receive buffer auto-tuning @@ -6120,35 +6264,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) { rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) rawEP.NextSeqNum++ - if i == 0 { // In the first iteration the receiver based RTT is not // yet known as a result the moderation code should not // increase the advertised window. - rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd)) - prevCopied = totalCopied + rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd)) } else { - rttCopied := totalCopied - if i == 1 { - // The moderation code accumulates copied bytes till - // RTT is established. So add in the bytes sent in - // the first iteration to the total bytes for this - // RTT. - rttCopied += prevCopied - // Now reset it to the initial value used by the - // auto tuning logic. - prevCopied = tcp.InitialCwnd * mss * 2 - } - newWnd := rttCopied<<1 + 16*mss - grow := (newWnd * (rttCopied - prevCopied)) / prevCopied - newWnd += (grow << 1) - if newWnd > maxReceiveBufferSize { - newWnd = maxReceiveBufferSize - done = true + pkt := c.GetPacket() + curRcvWnd = int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale + // If thew new current window is close maxReceiveBufferSize then terminate + // the loop. This can happen before all iterations are done due to timing + // differences when running the test. + if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 { + break } - rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd)) - wantRcvWnd = newWnd - prevCopied = rttCopied // Increase the latency after first two iterations to // establish a low RTT value in the receiver since it // only tracks the lowest value. This ensures that when @@ -6161,6 +6290,12 @@ func TestReceiveBufferAutoTuning(t *testing.T) { offset += payloadSize payloadSize *= 2 } + // Check that at the end of our iterations the receive window grew close to the maximum + // permissible size of maxReceiveBufferSize/2 + if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want { + t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want) + } + } func TestDelayEnabled(t *testing.T) { @@ -6169,7 +6304,7 @@ func TestDelayEnabled(t *testing.T) { checkDelayOption(t, c, false, false) // Delay is disabled by default. for _, v := range []struct { - delayEnabled tcp.DelayEnabled + delayEnabled tcpip.TCPDelayEnabled wantDelayOption bool }{ {delayEnabled: false, wantDelayOption: false}, @@ -6177,17 +6312,17 @@ func TestDelayEnabled(t *testing.T) { } { c := context.New(t, defaultMTU) defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil { - t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %s", v.delayEnabled, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err) } checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) } } -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) { t.Helper() - var gotDelayEnabled tcp.DelayEnabled + var gotDelayEnabled tcpip.TCPDelayEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) } @@ -6293,12 +6428,12 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6312,8 +6447,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6330,8 +6465,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Now send a RST and this should be ignored and not @@ -6359,8 +6494,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) } @@ -6412,12 +6547,12 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6431,8 +6566,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6449,8 +6584,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Out of order ACK should generate an immediate ACK in @@ -6466,8 +6601,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) } @@ -6519,12 +6654,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6538,8 +6673,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6556,8 +6691,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Send a SYN request w/ sequence number lower than @@ -6602,12 +6737,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { c.SendPacket(nil, ackHeaders) // Try to accept the connection. - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6625,8 +6760,9 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed // after 5 seconds in TIME_WAIT state. tcpTimeWaitTimeout := 5 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) } want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 @@ -6675,12 +6811,12 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6694,8 +6830,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6712,8 +6848,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) time.Sleep(2 * time.Second) @@ -6727,8 +6863,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Sleep for 4 seconds so at this point we are 1 second past the @@ -6756,8 +6892,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(0), + checker.TCPSeqNum(uint32(ackHeaders.AckNum)), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst))) if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { @@ -6775,8 +6911,9 @@ func TestTCPCloseWithData(t *testing.T) { // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed // after 5 seconds in TIME_WAIT state. tcpTimeWaitTimeout := 5 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) } wq := &waiter.Queue{} @@ -6824,12 +6961,12 @@ func TestTCPCloseWithData(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6855,8 +6992,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Now write a few bytes and then close the endpoint. @@ -6874,8 +7011,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -6889,8 +7026,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))), - checker.AckNum(uint32(iss+2)), + checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))), + checker.TCPAckNum(uint32(iss+2)), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) // First send a partial ACK. @@ -6935,8 +7072,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(0), + checker.TCPSeqNum(uint32(ackHeaders.AckNum)), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst))) } @@ -6972,8 +7109,8 @@ func TestTCPUserTimeout(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -7007,8 +7144,8 @@ func TestTCPUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next)), - checker.AckNum(uint32(0)), + checker.TCPSeqNum(uint32(next)), + checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -7069,8 +7206,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(uint32(790)), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(uint32(790)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7095,8 +7232,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(0)), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -7112,9 +7249,9 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { } } -func TestIncreaseWindowOnReceive(t *testing.T) { +func TestIncreaseWindowOnRead(t *testing.T) { // This test ensures that the endpoint sends an ack, - // after recv() when the window grows to more than 1 MSS. + // after read() when the window grows by more than 1 MSS. c := context.New(t, defaultMTU) defer c.Cleanup() @@ -7123,10 +7260,9 @@ func TestIncreaseWindowOnReceive(t *testing.T) { // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. - remain := rcvBuf + remain := rcvBuf * 2 sent := 0 data := make([]byte, defaultMTU/2) - lastWnd := uint16(0) for remain > len(data) { c.SendPacket(data, &context.Headers{ @@ -7139,46 +7275,43 @@ func TestIncreaseWindowOnReceive(t *testing.T) { }) sent += len(data) remain -= len(data) - - lastWnd = uint16(remain) - if remain > 0xffff { - lastWnd = 0xffff - } - checker.IPv4(t, c.GetPacket(), + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(lastWnd), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), checker.TCPFlags(header.TCPFlagAck), ), ) + // Break once the window drops below defaultMTU/2 + if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 { + break + } } - if lastWnd == 0xffff || lastWnd == 0 { - t.Fatalf("expected small, non-zero window: %d", lastWnd) - } - - // We now have < 1 MSS in the buffer space. Read the data! An - // ack should be sent in response to that. The window was not - // zero, but it grew to larger than MSS. - if _, _, err := c.EP.Read(nil); err != nil { - t.Fatalf("Read failed: %s", err) - } - - if _, _, err := c.EP.Read(nil); err != nil { - t.Fatalf("Read failed: %s", err) + // We now have < 1 MSS in the buffer space. Read at least > 2 MSS + // worth of data as receive buffer space + read := 0 + // defaultMTU is a good enough estimate for the MSS used for this + // connection. + for read < defaultMTU*2 { + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + read += len(v) } - // After reading two packets, we surely crossed MSS. See the ack: + // After reading > MSS worth of data, we surely crossed MSS. See the ack: checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(0xffff)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7195,10 +7328,9 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. - remain := rcvBuf + remain := rcvBuf * 2 sent := 0 data := make([]byte, defaultMTU/2) - lastWnd := uint16(0) for remain > len(data) { c.SendPacket(data, &context.Headers{ @@ -7212,38 +7344,29 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { sent += len(data) remain -= len(data) - lastWnd = uint16(remain) - if remain > 0xffff { - lastWnd = 0xffff - } checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(lastWnd), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindowLessThanEq(0xffff), checker.TCPFlags(header.TCPFlagAck), ), ) } - if lastWnd == 0xffff || lastWnd == 0 { - t.Fatalf("expected small, non-zero window: %d", lastWnd) - } - // Increasing the buffer from should generate an ACK, // since window grew from small value to larger equal MSS c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2) - // After reading two packets, we surely crossed MSS. See the ack: checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(0xffff)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7271,8 +7394,8 @@ func TestTCPDeferAccept(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) + if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) } // Send data. This should result in an acceptable endpoint. @@ -7288,14 +7411,14 @@ func TestTCPDeferAccept(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) // Give a bit of time for the socket to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept() + aep, _, err := c.EP.Accept(nil) if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) + t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) } aep.Close() @@ -7303,8 +7426,8 @@ func TestTCPDeferAccept(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } func TestTCPDeferAcceptTimeout(t *testing.T) { @@ -7329,8 +7452,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) + if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) } // Sleep for a little of the tcpDeferAccept timeout. @@ -7341,7 +7464,7 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) + checker.TCPAckNum(uint32(irs)+1))) // Send data. This should result in an acceptable endpoint. c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ @@ -7357,14 +7480,14 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) // Give sometime for the endpoint to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept() + aep, _, err := c.EP.Accept(nil) if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) + t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) } aep.Close() @@ -7373,8 +7496,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } func TestResetDuringClose(t *testing.T) { @@ -7399,8 +7522,8 @@ func TestResetDuringClose(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(irs.Add(1))), - checker.AckNum(uint32(iss.Add(5))))) + checker.TCPSeqNum(uint32(irs.Add(1))), + checker.TCPAckNum(uint32(iss.Add(5))))) // Close in a separate goroutine so that we can trigger // a race with the RST we send below. This should not @@ -7462,9 +7585,10 @@ func TestSetStackTimeWaitReuse(t *testing.T) { } for _, tc := range testCases { - err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitReuseOption(tc.v)) + opt := tcpip.TCPTimeWaitReuseOption(tc.v) + err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt) if got, want := err, tc.err; got != want { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.v, err, tc.err) + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err) } if tc.err != nil { continue @@ -7480,3 +7604,14 @@ func TestSetStackTimeWaitReuse(t *testing.T) { } } } + +// generateRandomPayload generates a random byte slice of the specified length +// causing a fatal test failure if it is unable to do so. +func generateRandomPayload(t *testing.T, n int) []byte { + t.Helper() + buf := make([]byte, n) + if _, err := rand.Read(buf); err != nil { + t.Fatalf("rand.Read(buf) failed: %s", err) + } + return buf +} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 8edbff964..0f9ed06cd 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -131,8 +131,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS defer c.Cleanup() if cookieEnabled { - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -158,9 +159,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS checker.PayloadLen(len(data)+header.TCPMinimumSize+12), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(wndSize), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), checker.TCPTimestampChecker(true, 0, tsVal+1), ), @@ -180,7 +181,8 @@ func TestTimeStampEnabledAccept(t *testing.T) { wndSize uint16 }{ {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. + // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that. + {false, 5, 0x4000}, } for _, tc := range testCases { timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) @@ -192,8 +194,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd defer c.Cleanup() if cookieEnabled { - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -217,9 +220,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(wndSize), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), checker.TCPTimestampChecker(false, 0, 0), ), @@ -235,7 +238,9 @@ func TestTimeStampDisabledAccept(t *testing.T) { wndSize uint16 }{ {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. + // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of + // that. + {false, 5, 0x4000}, } for _, tc := range testCases { timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 1f5340cd0..faf51ef95 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -53,11 +53,11 @@ const ( TestPort = 4096 // StackV6Addr is the IPv6 address assigned to the stack. - StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" // TestV6Addr is the source address for packets sent to the stack via // the link layer endpoint. - TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" // StackV4MappedAddr is StackAddr as a mapped v6 address. StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr @@ -73,6 +73,18 @@ const ( testInitialSequenceNumber = 789 ) +// StackAddrWithPrefix is StackAddr with its associated prefix length. +var StackAddrWithPrefix = tcpip.AddressWithPrefix{ + Address: StackAddr, + PrefixLen: 24, +} + +// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length. +var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{ + Address: StackV6Addr, + PrefixLen: header.IIDOffsetInIPv6Address * 8, +} + // Headers is used to represent the TCP header fields when building a // new packet. type Headers struct { @@ -133,32 +145,39 @@ type Context struct { // WindowScale is the expected window scale in SYN packets sent by // the stack. WindowScale uint8 + + // RcvdWindowScale is the actual window scale sent by the stack in + // SYN/SYN-ACK. + RcvdWindowScale uint8 } // New allocates and initializes a test context containing a new // stack and a link-layer endpoint. func New(t *testing.T, mtu uint32) *Context { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) const sendBufferSize = 1 << 20 // 1 MiB const recvBufferSize = 1 << 20 // 1 MiB // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + sendBufOpt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sendBufOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, sendBufOpt, err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + rcvBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvBufOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, rcvBufOpt, err) } // Increase minimum RTO in tests to avoid test flakes due to early // retransmit in case the test executors are overloaded and cause timers // to fire earlier than expected. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil { - t.Fatalf("failed to set stack-wide minRTO: %s", err) + minRTOOpt := tcpip.TCPMinRTOOption(3 * time.Second) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) } // Some of the congestion control tests send up to 640 packets, we so @@ -181,12 +200,20 @@ func New(t *testing.T, mtu uint32) *Context { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + v4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: StackAddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) } - if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + v6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: StackV6AddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -238,18 +265,17 @@ func (c *Context) CheckNoPacket(errMsg string) { c.CheckNoPacketTimeout(errMsg, 1*time.Second) } -// GetPacket reads a packet from the link layer endpoint and verifies +// GetPacketWithTimeout reads a packet from the link layer endpoint and verifies // that it is an IPv4 packet with the expected source and destination -// addresses. It will fail with an error if no packet is received for -// 2 seconds. -func (c *Context) GetPacket() []byte { +// addresses. If no packet is received in the specified timeout it will return +// nil. +func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte { c.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { - c.t.Fatalf("Packet wasn't written out") return nil } @@ -257,6 +283,14 @@ func (c *Context) GetPacket() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } + // Just check that the stack set the transport protocol number for outbound + // TCP messages. + // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part + // of the headerinfo. + if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber { + c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber) + } + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) b := vv.ToView() @@ -268,6 +302,21 @@ func (c *Context) GetPacket() []byte { return b } +// GetPacket reads a packet from the link layer endpoint and verifies +// that it is an IPv4 packet with the expected source and destination +// addresses. +func (c *Context) GetPacket() []byte { + c.t.Helper() + + p := c.GetPacketWithTimeout(5 * time.Second) + if p == nil { + c.t.Fatalf("Packet wasn't written out") + return nil + } + + return p +} + // GetPacketNonBlocking reads a packet from the link layer endpoint // and verifies that it is an IPv4 packet with the expected source // and destination address. If no packet is available it will return @@ -284,6 +333,14 @@ func (c *Context) GetPacketNonBlocking() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } + // Just check that the stack set the transport protocol number for outbound + // TCP messages. + // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part + // of the headerinfo. + if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber { + c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber) + } + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) b := vv.ToView() @@ -447,8 +504,8 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op checker.PayloadLen(size+header.TCPMinimumSize+optlen), checker.TCP( checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), + checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -474,8 +531,8 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int checker.PayloadLen(size+header.TCPMinimumSize), checker.TCP( checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), + checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -613,6 +670,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) } tcpHdr := header.TCP(header.IPv4(b).Payload()) + synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */) c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) c.SendPacket(nil, &Headers{ @@ -630,8 +688,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) checker.TCP( checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), ), ) @@ -648,6 +706,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) } + c.RcvdWindowScale = uint8(synOpts.WS) c.Port = tcpHdr.SourcePort() } @@ -719,17 +778,18 @@ func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) { r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload))) } -// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided -// tsVal. -func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { +// VerifyAndReturnACKWithTS verifies that the tsEcr field int he ACK matches +// the provided tsVal as well as returns the original packet. +func (r *RawEndpoint) VerifyAndReturnACKWithTS(tsVal uint32) []byte { + r.C.t.Helper() // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4] ackPacket := r.C.GetPacket() checker.IPv4(r.C.t, ackPacket, checker.TCP( checker.DstPort(r.SrcPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), + checker.TCPSeqNum(uint32(r.AckNum)), + checker.TCPAckNum(uint32(r.NextSeqNum)), checker.TCPTimestampChecker(true, 0, tsVal), ), ) @@ -737,19 +797,28 @@ func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { tcpSeg := header.TCP(header.IPv4(ackPacket).Payload()) opts := tcpSeg.ParsedOptions() r.RecentTS = opts.TSVal + return ackPacket +} + +// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided +// tsVal. +func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { + r.C.t.Helper() + _ = r.VerifyAndReturnACKWithTS(tsVal) } // VerifyACKRcvWnd verifies that the window advertised by the incoming ACK // matches the provided rcvWnd. func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) { + r.C.t.Helper() ackPacket := r.C.GetPacket() checker.IPv4(r.C.t, ackPacket, checker.TCP( checker.DstPort(r.SrcPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), - checker.Window(rcvWnd), + checker.TCPSeqNum(uint32(r.AckNum)), + checker.TCPAckNum(uint32(r.NextSeqNum)), + checker.TCPWindow(rcvWnd), ), ) } @@ -768,8 +837,8 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { checker.TCP( checker.DstPort(r.SrcPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), + checker.TCPSeqNum(uint32(r.AckNum)), + checker.TCPAckNum(uint32(r.NextSeqNum)), checker.TCPSACKBlockChecker(sackBlocks), ), ) @@ -861,8 +930,8 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * tcpCheckers := []checker.TransportChecker{ checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS) + 1), - checker.AckNum(uint32(iss) + 1), + checker.TCPSeqNum(uint32(c.IRS) + 1), + checker.TCPAckNum(uint32(iss) + 1), } // Verify that tsEcr of ACK packet is wantOptions.TSVal if the @@ -897,7 +966,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * // Mark in context that timestamp option is enabled for this endpoint. c.TimeStampEnabled = true - + c.RcvdWindowScale = uint8(synOptions.WS) return &RawEndpoint{ C: c, SrcPort: tcpSeg.DestinationPort(), @@ -948,12 +1017,12 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { c.t.Fatalf("Accept failed: %v", err) } @@ -990,6 +1059,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP // value of the window scaling option to be sent in the SYN. If synOptions.WS > // 0 then we send the WindowScale option. func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { + c.t.Helper() opts := make([]byte, header.TCPOptionsMaximumSize) offset := 0 offset += header.EncodeMSSOption(uint32(maxPayload), opts) @@ -1028,13 +1098,14 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // are present. b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) + rcvdSynOptions := header.ParseSynOptions(tcp.Options(), true /* isAck */) c.IRS = seqnum.Value(tcp.SequenceNumber()) tcpCheckers := []checker.TransportChecker{ checker.SrcPort(StackPort), checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(iss) + 1), + checker.TCPAckNum(uint32(iss) + 1), checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}), } @@ -1077,6 +1148,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // Send ACK. c.SendPacket(nil, ackHeaders) + c.RcvdWindowScale = uint8(rcvdSynOptions.WS) c.Port = StackPort return &RawEndpoint{ @@ -1096,7 +1168,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true // for the Stack in the context. func (c *Context) SACKEnabled() bool { - var v tcp.SACKEnabled + var v tcpip.TCPSACKEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return false diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index b5d2d0ba6..c78549424 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index c74bc4d94..d57ed5d79 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -139,7 +139,7 @@ type endpoint struct { // multicastMemberships that need to be remvoed when the endpoint is // closed. Protected by the mu mutex. - multicastMemberships []multicastMembership + multicastMemberships map[multicastMembership]struct{} // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -154,6 +154,9 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner + + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption } // +stateify savable @@ -182,12 +185,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // TTL=1. // // Linux defaults to TTL=1. - multicastTTL: 1, - multicastLoop: true, - rcvBufSizeMax: 32 * 1024, - sndBufSizeMax: 32 * 1024, - state: StateInitial, - uniqueID: s.UniqueID(), + multicastTTL: 1, + multicastLoop: true, + rcvBufSizeMax: 32 * 1024, + sndBufSizeMax: 32 * 1024, + multicastMemberships: make(map[multicastMembership]struct{}), + state: StateInitial, + uniqueID: s.UniqueID(), } // Override with stack defaults. @@ -237,10 +241,10 @@ func (e *endpoint) Close() { e.boundPortFlags = ports.Flags{} } - for _, mem := range e.multicastMemberships { + for mem := range e.multicastMemberships { e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr) } - e.multicastMemberships = nil + e.multicastMemberships = make(map[multicastMembership]struct{}) // Close the receive list and drain it. e.rcvMu.Lock() @@ -752,17 +756,15 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - for _, mem := range e.multicastMemberships { - if mem == memToInsert { - return tcpip.ErrPortInUse - } + if _, ok := e.multicastMemberships[memToInsert]; ok { + return tcpip.ErrPortInUse } if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { return err } - e.multicastMemberships = append(e.multicastMemberships, memToInsert) + e.multicastMemberships[memToInsert] = struct{}{} case *tcpip.RemoveMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { @@ -786,18 +788,11 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { } memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} - memToRemoveIndex := -1 e.mu.Lock() defer e.mu.Unlock() - for i, mem := range e.multicastMemberships { - if mem == memToRemove { - memToRemoveIndex = i - break - } - } - if memToRemoveIndex == -1 { + if _, ok := e.multicastMemberships[memToRemove]; !ok { return tcpip.ErrBadLocalAddress } @@ -805,8 +800,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { return err } - e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1] - e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] + delete(e.multicastMemberships, memToRemove) case *tcpip.BindToDeviceOption: id := tcpip.NICID(*v) @@ -819,6 +813,11 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.SocketDetachFilterOption: return nil + + case *tcpip.LingerOption: + e.mu.Lock() + e.linger = *v + e.mu.Unlock() } return nil } @@ -975,6 +974,11 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { *o = tcpip.BindToDeviceOption(e.bindToDevice) e.mu.RUnlock() + case *tcpip.LingerOption: + e.mu.RLock() + *o = e.linger + e.mu.RUnlock() + default: return tcpip.ErrUnknownProtocolOption } @@ -992,6 +996,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // Initialize the UDP header. udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + pkt.TransportProtocolNumber = ProtocolNumber length := uint16(pkt.Size()) udp.Encode(&header.UDPFields{ @@ -1218,7 +1223,7 @@ func (*endpoint) Listen(int) *tcpip.Error { } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -1392,15 +1397,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } - // Never receive from a multicast address. - if header.IsV4MulticastAddress(id.RemoteAddress) || - header.IsV6MulticastAddress(id.RemoteAddress) { - e.stack.Stats().UDP.InvalidSourceAddress.Increment() - e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment() - e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() - return - } - if !verifyChecksum(r, hdr, pkt) { // Checksum Error. e.stack.Stats().UDP.ChecksumErrors.Increment() diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 851e6b635..858c99a45 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -92,7 +92,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - for _, m := range e.multicastMemberships { + for m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index f65751dd4..da5b1deb2 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -12,18 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package udp contains the implementation of the UDP transport protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing udp.NewProtocol() as one of the -// transport protocols when calling stack.New(). Then endpoints can be created -// by passing udp.ProtocolNumber as the transport protocol number when calling -// Stack.NewEndpoint(). +// Package udp contains the implementation of the UDP transport protocol. package udp import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/waiter" @@ -49,6 +45,7 @@ const ( ) type protocol struct { + stack *stack.Stack } // Number returns the udp protocol number. @@ -57,14 +54,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new udp endpoint. -func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newEndpoint(stack, netProto, waiterQueue), nil +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw UDP endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue) +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue) } // MinimumPacketSize returns the minimum valid udp packet size. @@ -79,135 +76,30 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { return h.SourcePort(), h.DestinationPort(), nil } -// HandleUnknownDestinationPacket handles packets targeted at this protocol but -// that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { +// HandleUnknownDestinationPacket handles packets that are targeted at this +// protocol but don't match any existing endpoint. +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { - // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() - return true + return stack.UnknownDestinationPacketMalformed } if !verifyChecksum(r, hdr, pkt) { - // Checksum Error. r.Stack().Stats().UDP.ChecksumErrors.Increment() - return true + return stack.UnknownDestinationPacketMalformed } - // Only send ICMP error if the address is not a multicast/broadcast - // v4/v6 address or the source is not the unspecified address. - // - // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4 - if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any { - return true - } - - // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination - // Unreachable messages with code: - // - // 2 (Protocol Unreachable), when the designated transport protocol - // is not supported; or - // - // 3 (Port Unreachable), when the designated transport protocol - // (e.g., UDP) is unable to demultiplex the datagram but has no - // protocol mechanism to inform the sender. - switch len(id.LocalAddress) { - case header.IPv4AddressSize: - if !r.Stack().AllowICMPMessage() { - r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment() - return true - } - // As per RFC 1812 Section 4.3.2.3 - // - // ICMP datagram SHOULD contain as much of the original - // datagram as possible without the length of the ICMP - // datagram exceeding 576 bytes - // - // NOTE: The above RFC referenced is different from the original - // recommendation in RFC 1122 where it mentioned that at least 8 - // bytes of the payload must be included. Today linux and other - // systems implement the] RFC1812 definition and not the original - // RFC 1122 requirement. - mtu := int(r.MTU()) - if mtu > header.IPv4MinimumProcessableDatagramSize { - mtu = header.IPv4MinimumProcessableDatagramSize - } - headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize - available := int(mtu) - headerLen - payloadLen := pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size() + pkt.Data.Size() - if payloadLen > available { - payloadLen = available - } - - // The buffers used by pkt may be used elsewhere in the system. - // For example, a raw or packet socket may use what UDP - // considers an unreachable destination. Thus we deep copy pkt - // to prevent multiple ownership and SR errors. - newHeader := append(buffer.View(nil), pkt.NetworkHeader().View()...) - newHeader = append(newHeader, pkt.TransportHeader().View()...) - payload := newHeader.ToVectorisedView() - payload.AppendView(pkt.Data.ToView()) - payload.CapLength(payloadLen) - - icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: headerLen, - Data: payload, - }) - icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4PortUnreachable) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) - - case header.IPv6AddressSize: - if !r.Stack().AllowICMPMessage() { - r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment() - return true - } - - // As per RFC 4443 section 2.4 - // - // (c) Every ICMPv6 error message (type < 128) MUST include - // as much of the IPv6 offending (invoking) packet (the - // packet that caused the error) as possible without making - // the error message packet exceed the minimum IPv6 MTU - // [IPv6]. - mtu := int(r.MTU()) - if mtu > header.IPv6MinimumMTU { - mtu = header.IPv6MinimumMTU - } - headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize - available := int(mtu) - headerLen - network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() - payloadLen := len(network) + len(transport) + pkt.Data.Size() - if payloadLen > available { - payloadLen = available - } - payload := buffer.NewVectorisedView(len(network)+len(transport), []buffer.View{network, transport}) - payload.Append(pkt.Data) - payload.CapLength(payloadLen) - - icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: headerLen, - Data: payload, - }) - icmpHdr := header.ICMPv6(icmpPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetCode(header.ICMPv6PortUnreachable) - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, icmpPkt.Data)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) - } - return true + return stack.UnknownDestinationPacketUnhandled } // SetOption implements stack.TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // Option implements stack.TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } @@ -219,11 +111,10 @@ func (*protocol) Wait() {} // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize) - return ok + return parse.UDP(pkt) } // NewProtocol returns a UDP transport protocol. -func NewProtocol() stack.TransportProtocol { - return &protocol{} +func NewProtocol(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s} } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 0cbc045d8..4e14a5fc5 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -294,8 +294,8 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() return newDualTestContextWithOptions(t, mtu, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) } @@ -388,6 +388,10 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } + if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want { + c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want) + } + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) b := vv.ToView() @@ -528,8 +532,8 @@ func newMinPayload(minSize int) []byte { func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}}) ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -803,8 +807,8 @@ func TestV4ReadSelfSource(t *testing.T) { } { t.Run(tt.name, func(t *testing.T) { c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, HandleLocal: tt.handleLocal, }) defer c.cleanup() @@ -924,42 +928,6 @@ func TestReadFromMulticast(t *testing.T) { } } -// TestReadFromMulticaststats checks that a discarded packet -// that that was sent with multicast SOURCE address increments -// the correct counters and that a regular packet does not. -func TestReadFromMulticastStats(t *testing.T) { - t.Helper() - for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - c.injectPacket(flow, payload, false) - - var want uint64 = 0 - if flow.isReverseMulticast() { - want = 1 - } - if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want { - t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) - } - if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want { - t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) - } - }) - } -} - // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY // and receive broadcast and unicast data. func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { @@ -1462,6 +1430,30 @@ func TestNoChecksum(t *testing.T) { } } +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct{} + +func (*testInterface) ID() tcpip.NICID { + return 0 +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (*testInterface) Enabled() bool { + return true +} + +func (*testInterface) LinkEndpoint() stack.LinkEndpoint { + return nil +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1479,16 +1471,19 @@ func TestTTL(t *testing.T) { if flow.isMulticast() { wantTTL = multicastTTL } else { - var p stack.NetworkProtocol + var p stack.NetworkProtocolFactory + var n tcpip.NetworkProtocolNumber if flow.isV4() { - p = ipv4.NewProtocol() + p = ipv4.NewProtocol + n = ipv4.ProtocolNumber } else { - p = ipv6.NewProtocol() + p = ipv6.NewProtocol + n = ipv6.ProtocolNumber } - ep := p.NewEndpoint(0, nil, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - })) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{p}, + }) + ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil, nil, nil) wantTTL = ep.DefaultTTL() ep.Close() } @@ -1512,18 +1507,6 @@ func TestSetTTL(t *testing.T) { c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } - var p stack.NetworkProtocol - if flow.isV4() { - p = ipv4.NewProtocol() - } else { - p = ipv6.NewProtocol() - } - ep := p.NewEndpoint(0, nil, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - })) - ep.Close() - testWrite(c, flow, checker.TTL(wantTTL)) }) } @@ -2343,17 +2326,18 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: remNetSubnetBcast, - requiresBroadcastOpt: true, + remoteAddr: remNetSubnetBcast, + // TODO(gvisor.dev/issue/3938): Once we support marking a route as + // broadcast, this test should require the broadcast option to be set. + requiresBroadcastOpt: false, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID1, e); err != nil { |