diff options
-rw-r--r-- | pkg/tcpip/buffer/view.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/header/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/header/checksum.go | 39 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 52 |
7 files changed, 129 insertions, 9 deletions
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 24479ea40..6c70e0d69 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -144,3 +144,9 @@ func (vv VectorisedView) ToView() View { func (vv VectorisedView) Views() []View { return vv.views } + +// Append appends the views in a vectorised view to this vectorised view. +func (vv *VectorisedView) Append(vv2 *VectorisedView) { + vv.views = append(vv.views, vv2.views...) + vv.size += vv2.size +} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 8f22ba3a5..66b37720c 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -22,6 +22,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/seqnum", ], ) diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 12f208fde..2c2fcbd9b 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -18,12 +18,11 @@ package header import ( "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" ) -// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the -// given byte array. -func Checksum(buf []byte, initial uint16) uint16 { - v := uint32(initial) +func calculateChecksum(buf []byte, initial uint32) uint16 { + v := initial l := len(buf) if l&1 != 0 { @@ -38,8 +37,40 @@ func Checksum(buf []byte, initial uint16) uint16 { return ChecksumCombine(uint16(v), uint16(v>>16)) } +// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the +// given byte array. +// +// The initial checksum must have been computed on an even number of bytes. +func Checksum(buf []byte, initial uint16) uint16 { + return calculateChecksum(buf, uint32(initial)) +} + +// ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in +// the given VectorisedView. +// +// The initial checksum must have been computed on an even number of bytes. 
+func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 { + var odd bool + sum := initial + for _, v := range vv.Views() { + if len(v) == 0 { + continue + } + s := uint32(sum) + if odd { + s += uint32(v[0]) + v = v[1:] + } + odd = len(v)&1 != 0 + sum = calculateChecksum(v, s) + } + return sum +} + // ChecksumCombine combines the two uint16 to form their checksum. This is done // by adding them and the carry. +// +// Note that checksum a must have been computed on an even number of bytes. func ChecksumCombine(a, b uint16) uint16 { v := uint32(a) + uint32(b) return uint16(v + v>>16) diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 965779a68..00cb39560 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -596,9 +596,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise if r.Capabilities()&stack.CapabilityChecksumOffload == 0 { length := uint16(hdr.UsedLength() + data.Size()) xsum := r.PseudoHeaderChecksum(ProtocolNumber) - for _, v := range data.Views() { - xsum = header.Checksum(v, xsum) - } + xsum = header.ChecksumVV(data, xsum) tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length)) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 673a65c31..0b395b5b0 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -243,6 +243,16 @@ type endpoint struct { connectingAddress tcpip.Address } +// StopWork halts packet processing. Only to be used in tests. +func (e *endpoint) StopWork() { + e.workMu.Lock() +} + +// ResumeWork resumes packet processing. Only to be used in tests. +func (e *endpoint) ResumeWork() { + e.workMu.Unlock() +} + // keepalive is a synchronization wrapper used to appease stateify. See the // comment in endpoint, where it is used. 
// diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index b260c0e57..71bfcbd36 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -403,15 +403,36 @@ func (s *sender) sendData() { // TODO: We currently don't merge multiple send buffers // into one segment if they happen to fit. We should do that // eventually. - var seg *segment + seg := s.writeNext end := s.sndUna.Add(s.sndWnd) var dataSent bool - for seg = s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { + for next := (*segment)(nil); seg != nil && s.outstanding < s.sndCwnd; seg = next { + next = seg.Next() + // We abuse the flags field to determine if we have already // assigned a sequence number to this segment. if seg.flags == 0 { seg.sequenceNumber = s.sndNxt seg.flags = flagAck | flagPsh + // Merge segments if allowed. + if seg.data.Size() != 0 { + available := int(seg.sequenceNumber.Size(end)) + if available > limit { + available = limit + } + + for next != nil && next.data.Size() != 0 { + if seg.data.Size()+next.data.Size() > available { + break + } + + seg.data.Append(&next.data) + + // Consume the segment that we just merged in. + s.writeList.Remove(next) + next = next.Next() + } + } } var segEnd seqnum.Value @@ -442,6 +463,7 @@ func (s *sender) sendData() { nSeg.data.TrimFront(available) nSeg.sequenceNumber.UpdateForward(seqnum.Size(available)) s.writeList.InsertAfter(seg, nSeg) + next = nSeg seg.data.CapLength(available) } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 04e046257..75868c4a2 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1254,6 +1254,58 @@ func TestZeroScaledWindowReceive(t *testing.T) { ) } +func TestSegmentMerging(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Prevent the endpoint from processing packets. 
+ worker := c.EP.(interface { + StopWork() + ResumeWork() + }) + worker.StopWork() + + var allData []byte + for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { + allData = append(allData, data...) + view := buffer.NewViewFromBytes(data) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write #%d failed: %v", i+1, err) + } + } + + // Let the endpoint process the segments that we just sent. + worker.ResumeWork() + + // Check that data is received. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(allData)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) { + t.Fatalf("got data = %v, want = %v", got, allData) + } + + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))), + RcvWnd: 30000, + }) +} + func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { payloadMultiplier := 10 dataLen := payloadMultiplier * maxPayload |