diff options
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 53 | ||||
-rw-r--r-- | pkg/sentry/socket/epsocket/provider.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 47 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 137 |
6 files changed, 230 insertions, 57 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 39a0b9941..d14bbad01 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -157,7 +157,13 @@ type SocketOperations struct { } // New creates a new endpoint socket. -func New(t *kernel.Task, family int, skType transport.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) *fs.File { +func New(t *kernel.Task, family int, skType transport.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { + if skType == transport.SockStream { + if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + } + dirent := socket.NewDirent(t, epsocketDevice) defer dirent.DecRef() return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true}, &SocketOperations{ @@ -165,7 +171,7 @@ func New(t *kernel.Task, family int, skType transport.SockType, queue *waiter.Qu family: family, Endpoint: endpoint, skType: skType, - }) + }), nil } var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{})) @@ -426,10 +432,10 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *wait // tcpip.Endpoint. func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. - ep, wq, err := s.Endpoint.Accept() - if err != nil { - if err != tcpip.ErrWouldBlock || !blocking { - return 0, nil, 0, syserr.TranslateNetstackError(err) + ep, wq, terr := s.Endpoint.Accept() + if terr != nil { + if terr != tcpip.ErrWouldBlock || !blocking { + return 0, nil, 0, syserr.TranslateNetstackError(terr) } var err *syserr.Error @@ -439,7 +445,10 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } } - ns := New(t, s.family, s.skType, wq, ep) + ns, err := New(t, s.family, s.skType, wq, ep) + if err != nil { + return 0, nil, 0, err + } defer ns.DecRef() if flags&linux.SOCK_NONBLOCK != 0 { @@ -632,7 +641,22 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, return nil, syserr.ErrInvalidArgument } - var v tcpip.NoDelayOption + var v tcpip.DelayOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + if v == 0 { + return int32(1), nil + } + return int32(0), nil + + case syscall.TCP_CORK: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.CorkOption if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -748,7 +772,18 @@ func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, n } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.NoDelayOption(v))) + var o tcpip.DelayOption + if v == 0 { + o = 1 + } + return syserr.TranslateNetstackError(ep.SetSockOpt(o)) + case syscall.TCP_CORK: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.CorkOption(v))) } case syscall.SOL_IPV6: switch name { diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go index 686554437..0184d8e3e 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/epsocket/provider.go @@ -88,7 +88,7 @@ func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int return nil, syserr.TranslateNetstackError(e) } - return New(t, p.family, stype, wq, ep), nil + return New(t, p.family, stype, wq, ep) } // Pair just returns nil sockets (not supported). diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 413aee6c6..8e2fe70ee 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -420,10 +420,14 @@ type ReceiveQueueSizeOption int // socket is to be restricted to sending and receiving IPv6 packets only. type V6OnlyOption int -// NoDelayOption is used by SetSockOpt/GetSockOpt to specify if data should be +// DelayOption is used by SetSockOpt/GetSockOpt to specify if data should be // sent out immediately by the transport protocol. For TCP, it determines if the // Nagle algorithm is on or off. -type NoDelayOption int +type DelayOption int + +// CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be +// held until segments are full by the TCP transport protocol. +type CorkOption int // ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind() // should allow reuse of local address. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 0b395b5b0..96a546aa7 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -162,10 +162,19 @@ type endpoint struct { // sack holds TCP SACK related information for this endpoint. sack SACKInfo + // delay enables Nagle's algorithm. + // + // delay is a boolean (0 is false) and must be accessed atomically. + delay uint32 + + // cork holds back segments until full. + // + // cork is a boolean (0 is false) and must be accessed atomically. + cork uint32 + // The options below aren't implemented, but we remember the user // settings because applications expect to be able to set/query these // options. - noDelay bool reuseAddr bool // segmentQueue is used to hand received segments to the protocol @@ -276,7 +285,6 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite rcvBufSize: DefaultBufferSize, sndBufSize: DefaultBufferSize, sndMTU: int(math.MaxInt32), - noDelay: false, reuseAddr: true, keepalive: keepalive{ // Linux defaults. @@ -643,10 +651,24 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { switch v := opt.(type) { - case tcpip.NoDelayOption: - e.mu.Lock() - e.noDelay = v != 0 - e.mu.Unlock() + case tcpip.DelayOption: + if v == 0 { + atomic.StoreUint32(&e.delay, 0) + } else { + atomic.StoreUint32(&e.delay, 1) + } + return nil + + case tcpip.CorkOption: + if v == 0 { + atomic.StoreUint32(&e.cork, 0) + } else { + atomic.StoreUint32(&e.cork, 1) + } + + // Handle the corked data. + e.sndWaker.Assert() + return nil case tcpip.ReuseAddressOption: @@ -812,13 +834,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = tcpip.ReceiveQueueSizeOption(v) return nil - case *tcpip.NoDelayOption: - e.mu.RLock() - v := e.noDelay - e.mu.RUnlock() + case *tcpip.DelayOption: + *o = 0 + if v := atomic.LoadUint32(&e.delay); v != 0 { + *o = 1 + } + return nil + case *tcpip.CorkOption: *o = 0 - if v { + if v := atomic.LoadUint32(&e.cork); v != 0 { *o = 1 } return nil diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 4482d8d07..f6dc7520b 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -17,6 +17,7 @@ package tcp import ( "math" "sync" + "sync/atomic" "time" "gvisor.googlesource.com/gvisor/pkg/sleep" @@ -409,8 +410,6 @@ func (s *sender) sendData() { // We abuse the flags field to determine if we have already // assigned a sequence number to this segment. if seg.flags == 0 { - seg.sequenceNumber = s.sndNxt - seg.flags = flagAck | flagPsh // Merge segments if allowed. if seg.data.Size() != 0 { available := int(seg.sequenceNumber.Size(end)) @@ -418,8 +417,20 @@ func (s *sender) sendData() { available = limit } + // nextTooBig indicates that the next segment was too + // large to entirely fit in the current segment. It would + // be possible to split the next segment and merge the + // portion that fits, but unexpectedly splitting segments + // can have user visible side-effects which can break + // applications. For example, RFC 7766 section 8 says + // that the length and data of a DNS response should be + // sent in the same TCP segment to avoid triggering bugs + // in poorly written DNS implementations. + var nextTooBig bool + for next != nil && next.data.Size() != 0 { if seg.data.Size()+next.data.Size() > available { + nextTooBig = true break } @@ -429,7 +440,32 @@ func (s *sender) sendData() { s.writeList.Remove(next) next = next.Next() } + + if !nextTooBig && seg.data.Size() < available { + // Segment is not full. + if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 { + // Nagle's algorithm. From Wikipedia: + // Nagle's algorithm works by combining a number of + // small outgoing messages and sending them all at + // once. Specifically, as long as there is a sent + // packet for which the sender has received no + // acknowledgment, the sender should keep buffering + // its output until it has a full packet's worth of + // output, thus allowing output to be sent all at + // once. + break + } + if atomic.LoadUint32(&s.ep.cork) != 0 { + // Hold back the segment until full. + break + } + } } + + // Assign flags. We don't do it above so that we can merge + // additional data if Nagle holds the segment. + seg.sequenceNumber = s.sndNxt + seg.flags = flagAck | flagPsh } var segEnd seqnum.Value diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 75868c4a2..8155e4ed8 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1255,20 +1255,92 @@ func TestZeroScaledWindowReceive(t *testing.T) { } func TestSegmentMerging(t *testing.T) { + tests := []struct { + name string + stop func(tcpip.Endpoint) + resume func(tcpip.Endpoint) + }{ + { + "stop work", + func(ep tcpip.Endpoint) { + ep.(interface{ StopWork() }).StopWork() + }, + func(ep tcpip.Endpoint) { + ep.(interface{ ResumeWork() }).ResumeWork() + }, + }, + { + "cork", + func(ep tcpip.Endpoint) { + ep.SetSockOpt(tcpip.CorkOption(1)) + }, + func(ep tcpip.Endpoint) { + ep.SetSockOpt(tcpip.CorkOption(0)) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Prevent the endpoint from processing packets. + test.stop(c.EP) + + var allData []byte + for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { + allData = append(allData, data...) + view := buffer.NewViewFromBytes(data) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write #%d failed: %v", i+1, err) + } + } + + // Let the endpoint process the segments that we just sent. + test.resume(c.EP) + + // Check that data is received. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(allData)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) { + t.Fatalf("got data = %v, want = %v", got, allData) + } + + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))), + RcvWnd: 30000, + }) + }) + } +} + +func TestDelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() c.CreateConnected(789, 30000, nil) - // Prevent the endpoint from processing packets. - worker := c.EP.(interface { - StopWork() - ResumeWork() - }) - worker.StopWork() + c.EP.SetSockOpt(tcpip.DelayOption(1)) var allData []byte - for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { + for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) view := buffer.NewViewFromBytes(data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { @@ -1276,34 +1348,35 @@ func TestSegmentMerging(t *testing.T) { } } - // Let the endpoint process the segments that we just sent. - worker.ResumeWork() + seq := c.IRS.Add(1) + for _, want := range [][]byte{allData[:1], allData[1:]} { + // Check that data is received. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(want)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(seq)), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(allData)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) + if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) { + t.Fatalf("got data = %v, want = %v", got, want) + } - if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) { - t.Fatalf("got data = %v, want = %v", got, allData) + seq = seq.Add(seqnum.Size(len(want))) + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seq, + RcvWnd: 30000, + }) } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))), - RcvWnd: 30000, - }) } func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { |