summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/seqnum/seqnum.go5
-rw-r--r--pkg/tcpip/transport/tcp/BUILD10
-rw-r--r--pkg/tcpip/transport/tcp/accept.go14
-rw-r--r--pkg/tcpip/transport/tcp/connect.go15
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go13
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go23
-rw-r--r--pkg/tcpip/transport/tcp/rcv_test.go74
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD1
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go13
9 files changed, 125 insertions, 43 deletions
diff --git a/pkg/tcpip/seqnum/seqnum.go b/pkg/tcpip/seqnum/seqnum.go
index b40a3c212..d3bea7de4 100644
--- a/pkg/tcpip/seqnum/seqnum.go
+++ b/pkg/tcpip/seqnum/seqnum.go
@@ -46,11 +46,6 @@ func (v Value) InWindow(first Value, size Size) bool {
return v.InRange(first, first.Add(size))
}
-// Overlap checks if the window [a,a+b) overlaps with the window [x, x+y).
-func Overlap(a Value, b Size, x Value, y Size) bool {
- return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b))
-}
-
// Add calculates the sequence number following the [v, v+s) window.
func (v Value) Add(s Size) Value {
return v + Value(s)
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index edb7718a6..61426623c 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -109,3 +109,13 @@ go_test(
"//runsc/testutil",
],
)
+
+go_test(
+ name = "rcv_test",
+ size = "small",
+ srcs = ["rcv_test.go"],
+ deps = [
+ ":tcp",
+ "//pkg/tcpip/seqnum",
+ ],
+)
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 5bb243e3b..e6a23c978 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -101,7 +101,7 @@ type listenContext struct {
// v6Only is true if listenEP is a dual stack socket and has the
// IPV6_V6ONLY option set.
- v6only bool
+ v6Only bool
// netProto indicates the network protocol(IPv4/v6) for the listening
// endpoint.
@@ -126,12 +126,12 @@ func timeStamp() uint32 {
}
// newListenContext creates a new listen context.
-func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
stack: stk,
rcvWnd: rcvWnd,
hasher: sha1.New(),
- v6only: v6only,
+ v6Only: v6Only,
netProto: netProto,
listenEP: listenEP,
pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
@@ -207,7 +207,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
netProto = s.route.NetProto
}
n := newEndpoint(l.stack, netProto, queue)
- n.v6only = l.v6only
+ n.v6only = l.v6Only
n.ID = s.id
n.boundNICID = s.route.NICID()
n.route = s.route.Clone()
@@ -293,7 +293,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
}
// Perform the 3-way handshake.
- h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
+ h := newPassiveHandshake(ep, ep.rcv.rcvWnd, isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.mu.Unlock()
ep.Close()
@@ -613,8 +613,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// its own goroutine and is responsible for handling connection requests.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
- v6only := e.v6only
- ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto)
+ v6Only := e.v6only
+ ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 368865911..76e27bf26 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -105,24 +105,11 @@ type handshake struct {
}
func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
- rcvWndScale := ep.rcvWndScaleForHandshake()
-
- // Round-down the rcvWnd to a multiple of wndScale. This ensures that the
- // window offered in SYN won't be reduced due to the loss of precision if
- // window scaling is enabled after the handshake.
- rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale)
-
- // Ensure we can always accept at least 1 byte if the scale specified
- // was too high for the provided rcvWnd.
- if rcvWnd == 0 {
- rcvWnd = 1
- }
-
h := handshake{
ep: ep,
active: true,
rcvWnd: rcvWnd,
- rcvWndScale: int(rcvWndScale),
+ rcvWndScale: ep.rcvWndScaleForHandshake(),
}
h.resetState()
return h
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 7e8def82d..45f2aa78b 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1062,6 +1062,19 @@ func (e *endpoint) initialReceiveWindow() int {
if rcvWnd > routeWnd {
rcvWnd = routeWnd
}
+ rcvWndScale := e.rcvWndScaleForHandshake()
+
+ // Round-down the rcvWnd to a multiple of wndScale. This ensures that the
+ // window offered in SYN won't be reduced due to the loss of precision if
+ // window scaling is enabled after the handshake.
+ rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale)
+
+ // Ensure we can always accept at least 1 byte if the scale specified
+ // was too high for the provided rcvWnd.
+ if rcvWnd == 0 {
+ rcvWnd = 1
+ }
+
return rcvWnd
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index caf8977b3..a4b73b588 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -70,13 +70,24 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale
// acceptable checks if the segment sequence number range is acceptable
// according to the table on page 26 of RFC 793.
func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
- rcvWnd := r.rcvNxt.Size(r.rcvAcc)
- if rcvWnd == 0 {
- return segLen == 0 && segSeq == r.rcvNxt
- }
+ return Acceptable(segSeq, segLen, r.rcvNxt, r.rcvAcc)
+}
- return segSeq.InWindow(r.rcvNxt, rcvWnd) ||
- seqnum.Overlap(r.rcvNxt, rcvWnd, segSeq, segLen)
+// Acceptable checks if a segment that starts at segSeq and has length segLen is
+// "acceptable" for arriving in a receive window that starts at rcvNxt and ends
+// before rcvAcc, according to the table on page 26 and 69 of RFC 793.
+func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool {
+ if rcvNxt == rcvAcc {
+ return segLen == 0 && segSeq == rcvNxt
+ }
+ if segLen == 0 {
+ // rcvWnd is incremented by 1 because that is Linux's behavior despite the
+ // RFC.
+ return segSeq.InRange(rcvNxt, rcvAcc.Add(1))
+ }
+ // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming
+ // the payload, so we'll accept any payload that overlaps the receieve window.
+ return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThan(rcvAcc)
}
// getSendParams returns the parameters needed by the sender when building
diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go
new file mode 100644
index 000000000..dc02729ce
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv_test.go
@@ -0,0 +1,74 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rcv_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+)
+
+func TestAcceptable(t *testing.T) {
+ for _, tt := range []struct {
+ segSeq seqnum.Value
+ segLen seqnum.Size
+ rcvNxt, rcvAcc seqnum.Value
+ want bool
+ }{
+ // The segment is smaller than the window.
+ {105, 2, 100, 104, false},
+ {105, 2, 101, 105, false},
+ {105, 2, 102, 106, true},
+ {105, 2, 103, 107, true},
+ {105, 2, 104, 108, true},
+ {105, 2, 105, 109, true},
+ {105, 2, 106, 110, true},
+ {105, 2, 107, 111, false},
+
+ // The segment is larger than the window.
+ {105, 4, 103, 105, false},
+ {105, 4, 104, 106, true},
+ {105, 4, 105, 107, true},
+ {105, 4, 106, 108, true},
+ {105, 4, 107, 109, true},
+ {105, 4, 108, 110, true},
+ {105, 4, 109, 111, false},
+ {105, 4, 110, 112, false},
+
+ // The segment has no width.
+ {105, 0, 100, 102, false},
+ {105, 0, 101, 103, false},
+ {105, 0, 102, 104, false},
+ {105, 0, 103, 105, true},
+ {105, 0, 104, 106, true},
+ {105, 0, 105, 107, true},
+ {105, 0, 106, 108, false},
+ {105, 0, 107, 109, false},
+
+ // The receive window has no width.
+ {105, 2, 103, 103, false},
+ {105, 2, 104, 104, false},
+ {105, 2, 105, 105, false},
+ {105, 2, 106, 106, false},
+ {105, 2, 107, 107, false},
+ {105, 2, 108, 108, false},
+ {105, 2, 109, 109, false},
+ } {
+ if got := tcp.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want {
+ t.Errorf("tcp.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want)
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
index 3ad6994a7..2025ff757 100644
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -9,6 +9,7 @@ go_library(
deps = [
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
+ "//pkg/tcpip/transport/tcp",
],
)
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
index 93712cd45..30d05200f 100644
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -20,6 +20,7 @@ package tcpconntrack
import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
)
// Result is returned when the state of a TCB is updated in response to an
@@ -311,17 +312,7 @@ type stream struct {
// the window is zero, if it's a packet with no payload and sequence number
// equal to una.
func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
- wnd := s.una.Size(s.end)
- if wnd == 0 {
- return segLen == 0 && segSeq == s.una
- }
-
- // Make sure [segSeq, seqSeq+segLen) is non-empty.
- if segLen == 0 {
- segLen = 1
- }
-
- return seqnum.Overlap(s.una, wnd, segSeq, segLen)
+ return tcp.Acceptable(segSeq, segLen, s.una, s.end)
}
// closed determines if the stream has already been closed. This happens when