diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp')
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 52 |
4 files changed, 87 insertions, 5 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 965779a68..00cb39560 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -596,9 +596,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise if r.Capabilities()&stack.CapabilityChecksumOffload == 0 { length := uint16(hdr.UsedLength() + data.Size()) xsum := r.PseudoHeaderChecksum(ProtocolNumber) - for _, v := range data.Views() { - xsum = header.Checksum(v, xsum) - } + xsum = header.ChecksumVV(data, xsum) tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length)) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 673a65c31..0b395b5b0 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -243,6 +243,16 @@ type endpoint struct { connectingAddress tcpip.Address } +// StopWork halts packet processing. Only to be used in tests. +func (e *endpoint) StopWork() { + e.workMu.Lock() +} + +// ResumeWork resumes packet processing. Only to be used in tests. +func (e *endpoint) ResumeWork() { + e.workMu.Unlock() +} + // keepalive is a synchronization wrapper used to appease stateify. See the // comment in endpoint, where it is used. // diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index b260c0e57..71bfcbd36 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -403,15 +403,36 @@ func (s *sender) sendData() { // TODO: We currently don't merge multiple send buffers // into one segment if they happen to fit. We should do that // eventually. - var seg *segment + seg := s.writeNext end := s.sndUna.Add(s.sndWnd) var dataSent bool - for seg = s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { + for next := (*segment)(nil); seg != nil && s.outstanding < s.sndCwnd; seg = next { + next = seg.Next() + // We abuse the flags field to determine if we have already // assigned a sequence number to this segment. if seg.flags == 0 { seg.sequenceNumber = s.sndNxt seg.flags = flagAck | flagPsh + // Merge segments if allowed. + if seg.data.Size() != 0 { + available := int(seg.sequenceNumber.Size(end)) + if available > limit { + available = limit + } + + for next != nil && next.data.Size() != 0 { + if seg.data.Size()+next.data.Size() > available { + break + } + + seg.data.Append(&next.data) + + // Consume the segment that we just merged in. + s.writeList.Remove(next) + next = next.Next() + } + } } var segEnd seqnum.Value @@ -442,6 +463,7 @@ func (s *sender) sendData() { nSeg.data.TrimFront(available) nSeg.sequenceNumber.UpdateForward(seqnum.Size(available)) s.writeList.InsertAfter(seg, nSeg) + next = nSeg seg.data.CapLength(available) } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 04e046257..75868c4a2 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1254,6 +1254,58 @@ func TestZeroScaledWindowReceive(t *testing.T) { ) } +func TestSegmentMerging(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Prevent the endpoint from processing packets. + worker := c.EP.(interface { + StopWork() + ResumeWork() + }) + worker.StopWork() + + var allData []byte + for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { + allData = append(allData, data...) + view := buffer.NewViewFromBytes(data) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write #%d failed: %v", i+1, err) + } + } + + // Let the endpoint process the segments that we just sent. + worker.ResumeWork() + + // Check that data is received. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(allData)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) { + t.Fatalf("got data = %v, want = %v", got, allData) + } + + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))), + RcvWnd: 30000, + }) +} + func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { payloadMultiplier := 10 dataLen := payloadMultiplier * maxPayload |