diff options
-rw-r--r-- | pkg/tcpip/checker/checker.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 91 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 4 |
9 files changed, 121 insertions, 14 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index e0dfe5813..2f34bf8dd 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -729,7 +729,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp return } l := int(opts[i+1]) - if i < 2 || i+l > limit { + if l < 2 || i+l > limit { return } i += l diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index e0c5e5e28..8e5c6edbf 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -160,6 +160,10 @@ type Stack struct { // This is required to prevent potential ACK loops. // Setting this to 0 will disable all rate limiting. tcpInvalidRateLimit time.Duration + + // tsOffsetSecret is the secret key for generating timestamp offsets + // initialized at stack startup. + tsOffsetSecret uint32 } // UniqueID is an abstract generator of unique identifiers. @@ -383,6 +387,7 @@ func New(opts Options) *Stack { Max: DefaultMaxBufferSize, }, tcpInvalidRateLimit: defaultTCPInvalidRateLimit, + tsOffsetSecret: randomGenerator.Uint32(), } // Add specified network protocols. diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 8436d2cf0..c3922bbe5 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -96,6 +96,7 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index e4cc1b9f1..df1634a7a 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -602,7 +602,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, - TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())), + TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.protocol.tsOffsetSecret, s.dstAddr, s.srcAddr)), TSEcr: opts.TSVal, MSS: calculateAdvertisedMSS(e.userMSS, route), } @@ -728,11 +728,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n.isRegistered = true - // clear the tsOffset for the newly created - // endpoint as the Timestamp was already - // randomly offset when the original SYN-ACK was - // sent above. - n.TSOffset = 0 + // Reset the tsOffset for the newly created endpoint to the one + // that we have used in SYN-ACK in order to calculate RTT. + n.TSOffset = timeStampOffset(e.protocol.tsOffsetSecret, s.dstAddr, s.srcAddr) // Switch state to connected. n.isConnectNotified = true diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index e0f5e41b2..a80cbd52c 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -117,6 +117,8 @@ func (e *endpoint) newHandshake() *handshake { h.resetState() // Store reference to handshake state in endpoint. e.h = h + // By the time handshake is created, e.ID is already initialized. + e.TSOffset = timeStampOffset(e.protocol.tsOffsetSecret, e.ID.LocalAddress, e.ID.RemoteAddress) return h } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 1fc49033f..4937d126f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -20,7 +20,6 @@ import ( "fmt" "io" "math" - "math/rand" "runtime" "strings" "sync/atomic" @@ -876,7 +875,7 @@ func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProto } e.segmentQueue.ep = e - e.TSOffset = timeStampOffset(e.stack.Rand()) + e.acceptCond = sync.NewCond(&e.acceptMu) e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker) @@ -2929,17 +2928,22 @@ func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 { // timeStampOffset returns a randomized timestamp offset to be used when sending // timestamp values in a timestamp option for a TCP segment. -func timeStampOffset(rng *rand.Rand) uint32 { +func timeStampOffset(secret uint32, src, dst tcpip.Address) uint32 { // Initialize a random tsOffset that will be added to the recentTS // everytime the timestamp is sent when the Timestamp option is enabled. // // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on // why this is required. // - // NOTE: This is not completely to spec as normally this should be - // initialized in a manner analogous to how sequence numbers are - // randomized per connection basis. But for now this is sufficient. - return rng.Uint32() + // TODO(https://gvisor.dev/issues/6473): This is not really secure as + // it does not use the recommended algorithm linked above. + h := jenkins.Sum32(secret) + // Per hash.Hash.Writer: + // + // It never returns an error. + _, _ = h.Write([]byte(src)) + _, _ = h.Write([]byte(dst)) + return h.Sum32() } // maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 0b30cd3bb..174112214 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -99,6 +99,7 @@ type protocol struct { // The following secrets are initialized once and stay unchanged after. seqnumSecret uint32 portOffsetSecret uint32 + tsOffsetSecret uint32 } // Number returns the tcp protocol number. @@ -484,6 +485,7 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol { recovery: tcpip.TCPRACKLossDetection, seqnumSecret: s.Rand().Uint32(), portOffsetSecret: s.Rand().Uint32(), + tsOffsetSecret: s.Rand().Uint32(), } p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0)) return &p diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index db6b0955a..fb4481c25 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -28,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -8056,3 +8057,93 @@ func TestSendBufferTuning(t *testing.T) { }) } } + +func TestTimestampSynCookies(t *testing.T) { + clock := faketime.NewManualClock() + c := context.NewWithOpts(t, context.Options{ + EnableV4: true, + EnableV6: true, + MTU: defaultMTU, + Clock: clock, + }) + defer c.Cleanup() + opt := tcpip.TCPAlwaysUseSynCookies(true) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer ep.Close() + + tcpOpts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} + header.EncodeTSOption(42, 0, tcpOpts[2:]) + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + iss := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + RcvWnd: seqnum.Size(512), + SeqNum: iss, + TCPOpts: tcpOpts[:], + }) + // Get the TSVal of SYN-ACK. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + initialTSVal := tcpHdr.ParsedOptions().TSVal + + header.EncodeTSOption(420, initialTSVal, tcpOpts[2:]) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + RcvWnd: seqnum.Size(512), + SeqNum: iss + 1, + AckNum: c.IRS + 1, + TCPOpts: tcpOpts[:], + }) + c.EP, _, err = ep.Accept(nil) + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + defer wq.EventUnregister(&we) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } else if err != nil { + t.Fatalf("failed to accept: %s", err) + } + + const elapsed = 200 * time.Millisecond + clock.Advance(elapsed) + data := []byte{1, 2, 3} + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // The endpoint should have a correct TSOffset so that the received TSVal + // should match our expectation. + if got, want := header.TCP(header.IPv4(c.GetPacket()).Payload()).ParsedOptions().TSVal, initialTSVal+uint32(elapsed.Milliseconds()); got != want { + t.Fatalf("got TSVal = %d, want %d", got, want) + } +} diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 96e4849d2..fd746816d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -122,6 +122,9 @@ type Options struct { // MTU indicates the maximum transmission unit on the link layer. MTU uint32 + + // Clock that is used by Stack. + Clock tcpip.Clock } // Context provides an initialized Network stack and a link layer endpoint @@ -182,6 +185,7 @@ func NewWithOpts(t *testing.T, opts Options) *Context { stackOpts := stack.Options{ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + Clock: opts.Clock, } if opts.EnableV4 { stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol) |