summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/buffer/view.go6
-rw-r--r--pkg/tcpip/header/BUILD1
-rw-r--r--pkg/tcpip/header/checksum.go39
-rw-r--r--pkg/tcpip/transport/tcp/connect.go4
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go10
-rw-r--r--pkg/tcpip/transport/tcp/snd.go26
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go52
7 files changed, 129 insertions, 9 deletions
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 24479ea40..6c70e0d69 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -144,3 +144,9 @@ func (vv VectorisedView) ToView() View {
func (vv VectorisedView) Views() []View {
return vv.views
}
+
+// Append appends the views in a vectorised view to this vectorised view.
+func (vv *VectorisedView) Append(vv2 *VectorisedView) {
+ vv.views = append(vv.views, vv2.views...)
+ vv.size += vv2.size
+}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index 8f22ba3a5..66b37720c 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -22,6 +22,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
"//pkg/tcpip/seqnum",
],
)
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 12f208fde..2c2fcbd9b 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -18,12 +18,11 @@ package header
import (
"gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
)
-// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
-// given byte array.
-func Checksum(buf []byte, initial uint16) uint16 {
- v := uint32(initial)
+func calculateChecksum(buf []byte, initial uint32) uint16 {
+ v := initial
l := len(buf)
if l&1 != 0 {
@@ -38,8 +37,40 @@ func Checksum(buf []byte, initial uint16) uint16 {
return ChecksumCombine(uint16(v), uint16(v>>16))
}
+// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
+// given byte array.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func Checksum(buf []byte, initial uint16) uint16 {
+ return calculateChecksum(buf, uint32(initial))
+}
+
+// ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in
+// the given VectorizedView.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 {
+ var odd bool
+ sum := initial
+ for _, v := range vv.Views() {
+ if len(v) == 0 {
+ continue
+ }
+ s := uint32(sum)
+ if odd {
+ s += uint32(v[0])
+ v = v[1:]
+ }
+ odd = len(v)&1 != 0
+ sum = calculateChecksum(v, s)
+ }
+ return sum
+}
+
// ChecksumCombine combines the two uint16 to form their checksum. This is done
// by adding them and the carry.
+//
+// Note that checksum a must have been computed on an even number of bytes.
func ChecksumCombine(a, b uint16) uint16 {
v := uint32(a) + uint32(b)
return uint16(v + v>>16)
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 965779a68..00cb39560 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -596,9 +596,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise
if r.Capabilities()&stack.CapabilityChecksumOffload == 0 {
length := uint16(hdr.UsedLength() + data.Size())
xsum := r.PseudoHeaderChecksum(ProtocolNumber)
- for _, v := range data.Views() {
- xsum = header.Checksum(v, xsum)
- }
+ xsum = header.ChecksumVV(data, xsum)
tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length))
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 673a65c31..0b395b5b0 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -243,6 +243,16 @@ type endpoint struct {
connectingAddress tcpip.Address
}
+// StopWork halts packet processing. Only to be used in tests.
+func (e *endpoint) StopWork() {
+ e.workMu.Lock()
+}
+
+// ResumeWork resumes packet processing. Only to be used in tests.
+func (e *endpoint) ResumeWork() {
+ e.workMu.Unlock()
+}
+
// keepalive is a synchronization wrapper used to appease stateify. See the
// comment in endpoint, where it is used.
//
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index b260c0e57..71bfcbd36 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -403,15 +403,36 @@ func (s *sender) sendData() {
// TODO: We currently don't merge multiple send buffers
// into one segment if they happen to fit. We should do that
// eventually.
- var seg *segment
+ seg := s.writeNext
end := s.sndUna.Add(s.sndWnd)
var dataSent bool
- for seg = s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
+ for next := (*segment)(nil); seg != nil && s.outstanding < s.sndCwnd; seg = next {
+ next = seg.Next()
+
// We abuse the flags field to determine if we have already
// assigned a sequence number to this segment.
if seg.flags == 0 {
seg.sequenceNumber = s.sndNxt
seg.flags = flagAck | flagPsh
+ // Merge segments if allowed.
+ if seg.data.Size() != 0 {
+ available := int(seg.sequenceNumber.Size(end))
+ if available > limit {
+ available = limit
+ }
+
+ for next != nil && next.data.Size() != 0 {
+ if seg.data.Size()+next.data.Size() > available {
+ break
+ }
+
+ seg.data.Append(&next.data)
+
+ // Consume the segment that we just merged in.
+ s.writeList.Remove(next)
+ next = next.Next()
+ }
+ }
}
var segEnd seqnum.Value
@@ -442,6 +463,7 @@ func (s *sender) sendData() {
nSeg.data.TrimFront(available)
nSeg.sequenceNumber.UpdateForward(seqnum.Size(available))
s.writeList.InsertAfter(seg, nSeg)
+ next = nSeg
seg.data.CapLength(available)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 04e046257..75868c4a2 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -1254,6 +1254,58 @@ func TestZeroScaledWindowReceive(t *testing.T) {
)
}
+func TestSegmentMerging(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ // Prevent the endpoint from processing packets.
+ worker := c.EP.(interface {
+ StopWork()
+ ResumeWork()
+ })
+ worker.StopWork()
+
+ var allData []byte
+ for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
+ allData = append(allData, data...)
+ view := buffer.NewViewFromBytes(data)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write #%d failed: %v", i+1, err)
+ }
+ }
+
+ // Let the endpoint process the segments that we just sent.
+ worker.ResumeWork()
+
+ // Check that data is received.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(allData)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) {
+ t.Fatalf("got data = %v, want = %v", got, allData)
+ }
+
+ // Acknowledge the data.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))),
+ RcvWnd: 30000,
+ })
+}
+
func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
payloadMultiplier := 10
dataLen := payloadMultiplier * maxPayload