summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/tcp/accept.go6
-rw-r--r--pkg/tcpip/transport/tcp/connect.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go24
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go124
4 files changed, 150 insertions, 6 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 65c346046..1dd00d026 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -400,6 +400,9 @@ func (e *endpoint) acceptQueueIsFull() bool {
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
+ // TODO(b/143300739): Use the userMSS of the listening socket
+ // for accepted sockets.
+
switch s.flags {
case header.TCPFlagSyn:
opts := parseSynSegmentOptions(s)
@@ -434,13 +437,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
//
// Enable Timestamp option if the original syn did have
// the timestamp option specified.
- mss := mssForRoute(&s.route)
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
TSVal: tcpTimeStamp(timeStampOffset()),
TSEcr: opts.TSVal,
- MSS: uint16(mss),
+ MSS: mssForRoute(&s.route),
}
e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 790e89cc3..ca982c451 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -442,7 +442,7 @@ func (h *handshake) execute() *tcpip.Error {
// Send the initial SYN segment and loop until the handshake is
// completed.
- h.ep.amss = mssForRoute(&h.ep.route)
+ h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 6ca0d73a9..8234a8b53 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -411,7 +411,7 @@ type endpoint struct {
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
- userMSS int
+ userMSS uint16
// The following fields are used to manage the send buffer. When
// segments are ready to be sent, they are added to sndQueue and the
@@ -504,6 +504,21 @@ type endpoint struct {
stats Stats `state:"nosave"`
}
+// calculateAdvertisedMSS calculates the MSS to advertise.
+//
+// If userMSS is non-zero and is not greater than the maximum possible MSS for
+// r, it will be used; otherwise, the maximum possible MSS will be used.
+func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
+ // The maximum possible MSS is dependent on the route.
+ maxMSS := mssForRoute(&r)
+
+ if userMSS != 0 && userMSS < maxMSS {
+ return userMSS
+ }
+
+ return maxMSS
+}
+
// StopWork halts packet processing. Only to be used in tests.
func (e *endpoint) StopWork() {
e.workMu.Lock()
@@ -752,7 +767,9 @@ func (e *endpoint) initialReceiveWindow() int {
if rcvWnd > math.MaxUint16 {
rcvWnd = math.MaxUint16
}
- routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2
+
+ // Use the user supplied MSS, if available.
+ routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2
if rcvWnd > routeWnd {
rcvWnd = routeWnd
}
@@ -1206,7 +1223,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrInvalidOptionValue
}
e.mu.Lock()
- e.userMSS = int(userMSS)
+ e.userMSS = uint16(userMSS)
e.mu.Unlock()
e.notifyProtocolGoroutine(notifyMSSChanged)
return nil
@@ -2383,5 +2400,6 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
}
func mssForRoute(r *stack.Route) uint16 {
+ // TODO(b/143359391): Respect TCP Min and Max size.
return uint16(r.MTU() - header.TCPMinimumSize)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 6d808328c..126f26ed3 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -474,6 +474,130 @@ func TestSimpleReceive(t *testing.T) {
)
}
+// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when
+// creating a new active IPv4 TCP socket. It should be present in the sent TCP
+// SYN segment.
+func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
+ const mtu = 5000
+ const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ "EqualToMaxMSS",
+ maxMSS,
+ maxMSS,
+ },
+ {
+ "LessThanMTU",
+ maxMSS - 1,
+ maxMSS - 1,
+ },
+ {
+ "GreaterThanMTU",
+ maxMSS + 1,
+ maxMSS,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ // Set the MSS socket option.
+ opt := tcpip.MaxSegOption(test.setMSS)
+ if err := c.EP.SetSockOpt(opt); err != nil {
+ t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err)
+ }
+
+ // Get expected window size.
+ rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err)
+ }
+ ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+
+ // Start connection attempt to IPv4 address.
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet with our user supplied MSS.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ })
+ }
+}
+
+// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when
+// creating a new active IPv6 TCP socket. It should be present in the sent TCP
+// SYN segment.
+func TestUserSuppliedMSSOnConnectV6(t *testing.T) {
+ const mtu = 5000
+ const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ "EqualToMaxMSS",
+ maxMSS,
+ maxMSS,
+ },
+ {
+ "LessThanMTU",
+ maxMSS - 1,
+ maxMSS - 1,
+ },
+ {
+ "GreaterThanMTU",
+ maxMSS + 1,
+ maxMSS,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ // Set the MSS socket option.
+ opt := tcpip.MaxSegOption(test.setMSS)
+ if err := c.EP.SetSockOpt(opt); err != nil {
+ t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err)
+ }
+
+ // Get expected window size.
+ rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err)
+ }
+ ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+
+ // Start connection attempt to IPv6 address.
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet with our user supplied MSS.
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ })
+ }
+}
+
func TestTOSV4(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()