diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 277 |
3 files changed, 197 insertions, 101 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 913ea6535..b706438bd 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -212,7 +212,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.route = s.route.Clone() n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} n.rcvBufSize = int(l.rcvWnd) - n.amss = mssForRoute(&n.route) + n.amss = calculateAdvertisedMSS(n.userMSS, n.route) n.setEndpointState(StateConnecting) n.maybeEnableTimestamp(rcvdSynOpts) @@ -380,6 +380,7 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.portFlags = e.portFlags n.boundBindToDevice = e.boundBindToDevice n.boundPortFlags = e.boundPortFlags + n.userMSS = e.userMSS } // reserveTupleLocked reserves an accepted endpoint's tuple. @@ -481,9 +482,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { return } - // TODO(b/143300739): Use the userMSS of the listening socket - // for accepted sockets. - switch { case s.flags == header.TCPFlagSyn: opts := parseSynSegmentOptions(s) @@ -514,16 +512,19 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) // Send SYN without window scaling because we currently - // dont't encode this information in the cookie. + // don't encode this information in the cookie. // // Enable Timestamp option if the original syn did have // the timestamp option specified. + // + // Use the user supplied MSS on the listening socket for + // new connections, if available. synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), TSEcr: opts.TSVal, - MSS: mssForRoute(&s.route), + MSS: calculateAdvertisedMSS(e.userMSS, s.route), } e.sendSynTCP(&s.route, tcpFields{ id: s.id, diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index d08cfe0ff..1ccedebcc 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -667,7 +667,8 @@ func (e *endpoint) UniqueID() uint64 { // r, it will be used; otherwise, the maximum possible MSS will be used. func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { // The maximum possible MSS is dependent on the route. - maxMSS := mssForRoute(&r) + // TODO(b/143359391): Respect TCP Min and Max size. + maxMSS := uint16(r.MTU() - header.TCPMinimumSize) if userMSS != 0 && userMSS < maxMSS { return userMSS @@ -2966,8 +2967,3 @@ func (e *endpoint) Wait() { <-notifyCh } } - -func mssForRoute(r *stack.Route) uint16 { - // TODO(b/143359391): Respect TCP Min and Max size. - return uint16(r.MTU() - header.TCPMinimumSize) -} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 0f7e958e4..55ae09a2f 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -750,128 +750,227 @@ func TestSimpleReceive(t *testing.T) { ) } -// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when -// creating a new active IPv4 TCP socket. It should be present in the sent TCP +// TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when +// creating a new active TCP socket. It should be present in the sent TCP // SYN segment. -func TestUserSuppliedMSSOnConnectV4(t *testing.T) { +func TestUserSuppliedMSSOnConnect(t *testing.T) { const mtu = 5000 - const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - tests := []struct { - name string - setMSS int - expMSS uint16 + + ips := []struct { + name string + createEP func(*context.Context) + connectAddr tcpip.Address + checker func(*testing.T, *context.Context, uint16, int) + maxMSS uint16 }{ { - "EqualToMaxMSS", - maxMSS, - maxMSS, - }, - { - "LessThanMTU", - maxMSS - 1, - maxMSS - 1, + name: "IPv4", + createEP: func(c *context.Context) { + c.Create(-1) + }, + connectAddr: context.TestAddr, + checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) + }, + maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, }, { - "GreaterThanMTU", - maxMSS + 1, - maxMSS, + name: "IPv6", + createEP: func(c *context.Context) { + c.CreateV6Endpoint(true) + }, + connectAddr: context.TestV6Addr, + checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) + }, + maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() + for _, ip := range ips { + t.Run(ip.name, func(t *testing.T) { + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + name: "EqualToMaxMSS", + setMSS: ip.maxMSS, + expMSS: ip.maxMSS, + }, + { + name: "LessThanMaxMSS", + setMSS: ip.maxMSS - 1, + expMSS: ip.maxMSS - 1, + }, + { + name: "GreaterThanMaxMSS", + setMSS: ip.maxMSS + 1, + expMSS: ip.maxMSS, + }, + } - c.Create(-1) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() - // Set the MSS socket option. - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, test.setMSS); err != nil { - t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) - } + ip.createEP(c) - // Get expected window size. - rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) - } - ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + // Set the MSS socket option. + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) + } - // Start connection attempt to IPv4 address. - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("unexpected return value from Connect: %s", err) - } + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption): %s", err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) - // Receive SYN packet with our user supplied MSS. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} + if err := c.EP.Connect(connectAddr); err != tcpip.ErrConnectStarted { + t.Fatalf("Connect(%+v): %s", connectAddr, err) + } + + // Receive SYN packet with our user supplied MSS. + ip.checker(t, c, test.expMSS, ws) + }) + } }) } } -// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when -// creating a new active IPv6 TCP socket. It should be present in the sent TCP -// SYN segment. -func TestUserSuppliedMSSOnConnectV6(t *testing.T) { - const mtu = 5000 - const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize - tests := []struct { - name string - setMSS uint16 - expMSS uint16 +// TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used +// when completing the handshake for a new TCP connection from a TCP +// listening socket. It should be present in the sent TCP SYN-ACK segment. +func TestUserSuppliedMSSOnListenAccept(t *testing.T) { + const ( + nonSynCookieAccepts = 2 + totalAccepts = 4 + mtu = 5000 + ) + + ips := []struct { + name string + createEP func(*context.Context) + sendPkt func(*context.Context, *context.Headers) + checker func(*testing.T, *context.Context, uint16, uint16) + maxMSS uint16 }{ { - "EqualToMaxMSS", - maxMSS, - maxMSS, - }, - { - "LessThanMTU", - maxMSS - 1, - maxMSS - 1, + name: "IPv4", + createEP: func(c *context.Context) { + c.Create(-1) + }, + sendPkt: func(c *context.Context, h *context.Headers) { + c.SendPacket(nil, h) + }, + checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(srcPort), + checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) + }, + maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, }, { - "GreaterThanMTU", - maxMSS + 1, - maxMSS, + name: "IPv6", + createEP: func(c *context.Context) { + c.CreateV6Endpoint(false) + }, + sendPkt: func(c *context.Context, h *context.Headers) { + c.SendV6Packet(nil, h) + }, + checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(srcPort), + checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) + }, + maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() + for _, ip := range ips { + t.Run(ip.name, func(t *testing.T) { + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + name: "EqualToMaxMSS", + setMSS: ip.maxMSS, + expMSS: ip.maxMSS, + }, + { + name: "LessThanMaxMSS", + setMSS: ip.maxMSS - 1, + expMSS: ip.maxMSS - 1, + }, + { + name: "GreaterThanMaxMSS", + setMSS: ip.maxMSS + 1, + expMSS: ip.maxMSS, + }, + } - c.CreateV6Endpoint(true) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() - // Set the MSS socket option. - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { - t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) - } + ip.createEP(c) - // Get expected window size. - rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) - } - ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + // Set the SynRcvd threshold to force a syn cookie based accept to happen. + opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, %#v): %s", tcp.ProtocolNumber, opt, err) + } - // Start connection attempt to IPv6 address. - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("unexpected return value from Connect: %s", err) - } + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) + } - // Receive SYN packet with our user supplied MSS. - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + bindAddr := tcpip.FullAddress{Port: context.StackPort} + if err := c.EP.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s:", bindAddr, err) + } + + if err := c.EP.Listen(totalAccepts); err != nil { + t.Fatalf("Listen(%d): %s:", totalAccepts, err) + } + + // The first nonSynCookieAccepts packets sent will trigger a gorooutine + // based accept. The rest will trigger a cookie based accept. + for i := 0; i < totalAccepts; i++ { + // Send a SYN requests. + iss := seqnum.Value(i) + srcPort := context.TestPort + uint16(i) + ip.sendPkt(c, &context.Headers{ + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + ip.checker(t, c, srcPort, test.expMSS) + } + }) + } }) } } - func TestSendRstOnListenerRxSynAckV4(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() |