summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/transport/tcp/accept.go13
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go8
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go277
3 files changed, 197 insertions, 101 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 913ea6535..b706438bd 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -212,7 +212,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.route = s.route.Clone()
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
n.rcvBufSize = int(l.rcvWnd)
- n.amss = mssForRoute(&n.route)
+ n.amss = calculateAdvertisedMSS(n.userMSS, n.route)
n.setEndpointState(StateConnecting)
n.maybeEnableTimestamp(rcvdSynOpts)
@@ -380,6 +380,7 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
n.portFlags = e.portFlags
n.boundBindToDevice = e.boundBindToDevice
n.boundPortFlags = e.boundPortFlags
+ n.userMSS = e.userMSS
}
// reserveTupleLocked reserves an accepted endpoint's tuple.
@@ -481,9 +482,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
- // TODO(b/143300739): Use the userMSS of the listening socket
- // for accepted sockets.
-
switch {
case s.flags == header.TCPFlagSyn:
opts := parseSynSegmentOptions(s)
@@ -514,16 +512,19 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
// Send SYN without window scaling because we currently
- // dont't encode this information in the cookie.
+ // don't encode this information in the cookie.
//
// Enable Timestamp option if the original syn did have
// the timestamp option specified.
+ //
+ // Use the user supplied MSS on the listening socket for
+ // new connections, if available.
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
TSEcr: opts.TSVal,
- MSS: mssForRoute(&s.route),
+ MSS: calculateAdvertisedMSS(e.userMSS, s.route),
}
e.sendSynTCP(&s.route, tcpFields{
id: s.id,
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index d08cfe0ff..1ccedebcc 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -667,7 +667,8 @@ func (e *endpoint) UniqueID() uint64 {
// r, it will be used; otherwise, the maximum possible MSS will be used.
func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
// The maximum possible MSS is dependent on the route.
- maxMSS := mssForRoute(&r)
+ // TODO(b/143359391): Respect TCP Min and Max size.
+ maxMSS := uint16(r.MTU() - header.TCPMinimumSize)
if userMSS != 0 && userMSS < maxMSS {
return userMSS
@@ -2966,8 +2967,3 @@ func (e *endpoint) Wait() {
<-notifyCh
}
}
-
-func mssForRoute(r *stack.Route) uint16 {
- // TODO(b/143359391): Respect TCP Min and Max size.
- return uint16(r.MTU() - header.TCPMinimumSize)
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 0f7e958e4..55ae09a2f 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -750,128 +750,227 @@ func TestSimpleReceive(t *testing.T) {
)
}
-// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when
-// creating a new active IPv4 TCP socket. It should be present in the sent TCP
+// TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when
+// creating a new active TCP socket. It should be present in the sent TCP
// SYN segment.
-func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
+func TestUserSuppliedMSSOnConnect(t *testing.T) {
const mtu = 5000
- const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
- tests := []struct {
- name string
- setMSS int
- expMSS uint16
+
+ ips := []struct {
+ name string
+ createEP func(*context.Context)
+ connectAddr tcpip.Address
+ checker func(*testing.T, *context.Context, uint16, int)
+ maxMSS uint16
}{
{
- "EqualToMaxMSS",
- maxMSS,
- maxMSS,
- },
- {
- "LessThanMTU",
- maxMSS - 1,
- maxMSS - 1,
+ name: "IPv4",
+ createEP: func(c *context.Context) {
+ c.Create(-1)
+ },
+ connectAddr: context.TestAddr,
+ checker: func(t *testing.T, c *context.Context, mss uint16, ws int) {
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws})))
+ },
+ maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize,
},
{
- "GreaterThanMTU",
- maxMSS + 1,
- maxMSS,
+ name: "IPv6",
+ createEP: func(c *context.Context) {
+ c.CreateV6Endpoint(true)
+ },
+ connectAddr: context.TestV6Addr,
+ checker: func(t *testing.T, c *context.Context, mss uint16, ws int) {
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws})))
+ },
+ maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize,
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, mtu)
- defer c.Cleanup()
+ for _, ip := range ips {
+ t.Run(ip.name, func(t *testing.T) {
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ name: "EqualToMaxMSS",
+ setMSS: ip.maxMSS,
+ expMSS: ip.maxMSS,
+ },
+ {
+ name: "LessThanMaxMSS",
+ setMSS: ip.maxMSS - 1,
+ expMSS: ip.maxMSS - 1,
+ },
+ {
+ name: "GreaterThanMaxMSS",
+ setMSS: ip.maxMSS + 1,
+ expMSS: ip.maxMSS,
+ },
+ }
- c.Create(-1)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
- // Set the MSS socket option.
- if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, test.setMSS); err != nil {
- t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err)
- }
+ ip.createEP(c)
- // Get expected window size.
- rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
- if err != nil {
- t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
- }
- ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+ // Set the MSS socket option.
+ if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
+ t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
+ }
- // Start connection attempt to IPv4 address.
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("unexpected return value from Connect: %s", err)
- }
+ // Get expected window size.
+ rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption): %s", err)
+ }
+ ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
- // Receive SYN packet with our user supplied MSS.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort}
+ if err := c.EP.Connect(connectAddr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Connect(%+v): %s", connectAddr, err)
+ }
+
+ // Receive SYN packet with our user supplied MSS.
+ ip.checker(t, c, test.expMSS, ws)
+ })
+ }
})
}
}
-// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when
-// creating a new active IPv6 TCP socket. It should be present in the sent TCP
-// SYN segment.
-func TestUserSuppliedMSSOnConnectV6(t *testing.T) {
- const mtu = 5000
- const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize
- tests := []struct {
- name string
- setMSS uint16
- expMSS uint16
+// TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used
+// when completing the handshake for a new TCP connection from a TCP
+// listening socket. It should be present in the sent TCP SYN-ACK segment.
+func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
+ const (
+ nonSynCookieAccepts = 2
+ totalAccepts = 4
+ mtu = 5000
+ )
+
+ ips := []struct {
+ name string
+ createEP func(*context.Context)
+ sendPkt func(*context.Context, *context.Headers)
+ checker func(*testing.T, *context.Context, uint16, uint16)
+ maxMSS uint16
}{
{
- "EqualToMaxMSS",
- maxMSS,
- maxMSS,
- },
- {
- "LessThanMTU",
- maxMSS - 1,
- maxMSS - 1,
+ name: "IPv4",
+ createEP: func(c *context.Context) {
+ c.Create(-1)
+ },
+ sendPkt: func(c *context.Context, h *context.Headers) {
+ c.SendPacket(nil, h)
+ },
+ checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) {
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(srcPort),
+ checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1})))
+ },
+ maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize,
},
{
- "GreaterThanMTU",
- maxMSS + 1,
- maxMSS,
+ name: "IPv6",
+ createEP: func(c *context.Context) {
+ c.CreateV6Endpoint(false)
+ },
+ sendPkt: func(c *context.Context, h *context.Headers) {
+ c.SendV6Packet(nil, h)
+ },
+ checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) {
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(srcPort),
+ checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1})))
+ },
+ maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize,
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, mtu)
- defer c.Cleanup()
+ for _, ip := range ips {
+ t.Run(ip.name, func(t *testing.T) {
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ name: "EqualToMaxMSS",
+ setMSS: ip.maxMSS,
+ expMSS: ip.maxMSS,
+ },
+ {
+ name: "LessThanMaxMSS",
+ setMSS: ip.maxMSS - 1,
+ expMSS: ip.maxMSS - 1,
+ },
+ {
+ name: "GreaterThanMaxMSS",
+ setMSS: ip.maxMSS + 1,
+ expMSS: ip.maxMSS,
+ },
+ }
- c.CreateV6Endpoint(true)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
- // Set the MSS socket option.
- if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
- t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err)
- }
+ ip.createEP(c)
- // Get expected window size.
- rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
- if err != nil {
- t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
- }
- ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+ // Set the SynRcvd threshold to force a syn cookie based accept to happen.
+ opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, %#v): %s", tcp.ProtocolNumber, opt, err)
+ }
- // Start connection attempt to IPv6 address.
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("unexpected return value from Connect: %s", err)
- }
+ if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
+ t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
+ }
- // Receive SYN packet with our user supplied MSS.
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ bindAddr := tcpip.FullAddress{Port: context.StackPort}
+ if err := c.EP.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s:", bindAddr, err)
+ }
+
+ if err := c.EP.Listen(totalAccepts); err != nil {
+ t.Fatalf("Listen(%d): %s:", totalAccepts, err)
+ }
+
+ // The first nonSynCookieAccepts packets sent will trigger a gorooutine
+ // based accept. The rest will trigger a cookie based accept.
+ for i := 0; i < totalAccepts; i++ {
+ // Send a SYN requests.
+ iss := seqnum.Value(i)
+ srcPort := context.TestPort + uint16(i)
+ ip.sendPkt(c, &context.Headers{
+ SrcPort: srcPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ ip.checker(t, c, srcPort, test.expMSS)
+ }
+ })
+ }
})
}
}
-
func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()